Introduce ShiftMixParseVarint Shuffle function

PiperOrigin-RevId: 517483241
pull/12269/head
Martijn Vels 2 years ago committed by Copybara-Service
parent a7846178ef
commit f1237863de
  1. 32
      src/google/protobuf/BUILD.bazel
  2. 169
      src/google/protobuf/generated_message_tctable_lite.cc
  3. 188
      src/google/protobuf/varint_shuffle.h
  4. 331
      src/google/protobuf/varint_shuffle_test.cc

@ -183,6 +183,33 @@ cc_library(
],
)
cc_library(
name = "varint_shuffle",
hdrs = ["varint_shuffle.h"],
copts = COPTS,
include_prefix = "google/protobuf",
visibility = [
"//:__subpackages__",
"//src/google/protobuf:__subpackages__",
],
deps = [
":port_def",
"@com_google_absl//absl/log:absl_check",
],
)
cc_test(
name = "varint_shuffle_test",
srcs = ["varint_shuffle_test.cc"],
deps = [
":port_def",
":varint_shuffle",
"@com_google_absl//absl/log:absl_check",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "arena_align",
srcs = ["arena_align.cc"],
@ -361,6 +388,7 @@ cc_library(
":arena_align",
":arena_config",
":string_block",
":varint_shuffle",
"//src/google/protobuf/io",
"//src/google/protobuf/stubs:lite",
"@com_google_absl//absl/container:btree",
@ -487,12 +515,12 @@ cc_library(
copts = COPTS,
include_prefix = "google/protobuf",
linkopts = LINK_OPTS,
visibility = ["//:__subpackages__"],
deps = [
":protobuf_nowkt",
":port_def",
":protobuf_nowkt",
"@com_google_absl//absl/strings",
],
visibility = ["//:__subpackages__"]
)
filegroup(

@ -41,6 +41,7 @@
#include "google/protobuf/map.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/parse_context.h"
#include "google/protobuf/varint_shuffle.h"
#include "google/protobuf/wire_format_lite.h"
#include "utf8_validity.h"
@ -697,136 +698,6 @@ PROTOBUF_NOINLINE const char* TcParser::FastF64P2(PROTOBUF_TC_PARAM_DECL) {
namespace {
// Shift "byte" left by n * 7 bits, filling vacated bits with ones.
template <int n>
inline PROTOBUF_ALWAYS_INLINE int64_t shift_left_fill_with_ones(uint64_t byte,
uint64_t ones) {
return static_cast<int64_t>((byte << (n * 7)) | (ones >> (64 - (n * 7))));
}
// Shift "byte" left by n * 7 bits, filling vacated bits with ones, and
// put the new value in res. Return whether the result was negative.
template <int n>
inline PROTOBUF_ALWAYS_INLINE bool shift_left_fill_with_ones_was_negative(
uint64_t byte, uint64_t ones, int64_t& res) {
#if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__)
// For the first two rounds (up to 2 varint bytes), micro benchmarks show a
// substantial improvement from capturing the sign from the condition code
// register on x86-64.
bool sign_bit;
asm("shldq %3, %2, %1"
: "=@ccs"(sign_bit), "+r"(byte)
: "r"(ones), "i"(n * 7));
res = static_cast<int64_t>(byte);
return sign_bit;
#else
// Generic fallback:
res = shift_left_fill_with_ones<n>(byte, ones);
return res < 0;
#endif
}
template <class VarintType>
inline PROTOBUF_ALWAYS_INLINE std::pair<const char*, VarintType>
ParseFallbackPair(const char* p, int64_t res1) {
constexpr bool kIs64BitVarint = std::is_same<VarintType, uint64_t>::value;
constexpr bool kIs32BitVarint = std::is_same<VarintType, uint32_t>::value;
static_assert(kIs64BitVarint || kIs32BitVarint,
"Only 32 or 64 bit varints are supported");
auto ptr = reinterpret_cast<const int8_t*>(p);
// The algorithm relies on sign extension for each byte to set all high bits
// when the varint continues. It also relies on asserting all of the lower
// bits for each successive byte read. This allows the result to be aggregated
// using a bitwise AND. For example:
//
// 8 1 64 57 ... 24 17 16 9 8 1
// ptr[0] = 1aaa aaaa ; res1 = 1111 1111 ... 1111 1111 1111 1111 1aaa aaaa
// ptr[1] = 1bbb bbbb ; res2 = 1111 1111 ... 1111 1111 11bb bbbb b111 1111
// ptr[2] = 0ccc cccc ; res3 = 0000 0000 ... 000c cccc cc11 1111 1111 1111
// ---------------------------------------------
// res1 & res2 & res3 = 0000 0000 ... 000c cccc ccbb bbbb baaa aaaa
//
// On x86-64, a shld from a single register filled with enough 1s in the high
// bits can accomplish all this in one instruction. It so happens that res1
// has 57 high bits of ones, which is enough for the largest shift done.
//
// Just as importantly, by keeping results in res1, res2, and res3, we take
// advantage of the superscalar abilities of the CPU.
ABSL_DCHECK_EQ(res1 >> 7, -1);
uint64_t ones = res1; // save the high 1 bits from res1 (input to SHLD)
int64_t res2, res3; // accumulated result chunks
if (!shift_left_fill_with_ones_was_negative<1>(ptr[1], ones, res2))
goto done2;
if (!shift_left_fill_with_ones_was_negative<2>(ptr[2], ones, res3))
goto done3;
// For the remainder of the chunks, check the sign of the AND result.
res2 &= shift_left_fill_with_ones<3>(ptr[3], ones);
if (res2 >= 0) goto done4;
res1 &= shift_left_fill_with_ones<4>(ptr[4], ones);
if (res1 >= 0) goto done5;
if (kIs64BitVarint) {
res2 &= shift_left_fill_with_ones<5>(ptr[5], ones);
if (res2 >= 0) goto done6;
res3 &= shift_left_fill_with_ones<6>(ptr[6], ones);
if (res3 >= 0) goto done7;
res1 &= shift_left_fill_with_ones<7>(ptr[7], ones);
if (res1 >= 0) goto done8;
res3 &= shift_left_fill_with_ones<8>(ptr[8], ones);
if (res3 >= 0) goto done9;
} else if (kIs32BitVarint) {
if (PROTOBUF_PREDICT_TRUE(!(ptr[5] & 0x80))) goto done6;
if (PROTOBUF_PREDICT_TRUE(!(ptr[6] & 0x80))) goto done7;
if (PROTOBUF_PREDICT_TRUE(!(ptr[7] & 0x80))) goto done8;
if (PROTOBUF_PREDICT_TRUE(!(ptr[8] & 0x80))) goto done9;
}
// For valid 64bit varints, the 10th byte/ptr[9] should be exactly 1. In this
// case, the continuation bit of ptr[8] already set the top bit of res3
// correctly, so all we have to do is check that the expected case is true.
if (PROTOBUF_PREDICT_TRUE(kIs64BitVarint && ptr[9] == 1)) goto done10;
if (PROTOBUF_PREDICT_FALSE(ptr[9] & 0x80)) {
// If the continue bit is set, it is an unterminated varint.
return {nullptr, 0};
}
// A zero value of the first bit of the 10th byte represents an
// over-serialized varint. This case should not happen, but if does (say, due
// to a nonconforming serializer), deassert the continuation bit that came
// from ptr[8].
if (kIs64BitVarint && (ptr[9] & 1) == 0) {
#if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__)
// Use a small instruction since this is an uncommon code path.
asm("btcq $63,%0" : "+r"(res3));
#else
res3 ^= static_cast<uint64_t>(1) << 63;
#endif
}
goto done10;
done2:
return {p + 2, res1 & res2};
done3:
return {p + 3, res1 & res2 & res3};
done4:
return {p + 4, res1 & res2 & res3};
done5:
return {p + 5, res1 & res2 & res3};
done6:
return {p + 6, res1 & res2 & res3};
done7:
return {p + 7, res1 & res2 & res3};
done8:
return {p + 8, res1 & res2 & res3};
done9:
return {p + 9, res1 & res2 & res3};
done10:
return {p + 10, res1 & res2 & res3};
}
template <typename Type>
inline PROTOBUF_ALWAYS_INLINE const char* ParseVarint(const char* p,
Type* value) {
@ -841,17 +712,10 @@ inline PROTOBUF_ALWAYS_INLINE const char* ParseVarint(const char* p,
}
return p;
#endif
int64_t byte = static_cast<int8_t>(*p);
if (PROTOBUF_PREDICT_TRUE(byte >= 0)) {
*value = byte;
return p + 1;
} else {
auto tmp = ParseFallbackPair<std::make_unsigned_t<Type>>(p, byte);
if (PROTOBUF_PREDICT_TRUE(tmp.first)) {
*value = static_cast<Type>(tmp.second);
}
return tmp.first;
}
int64_t res;
p = ShiftMixParseVarint<Type>(p, res);
*value = res;
return p;
}
// This overload is specifically for handling bool, because bools have very
@ -979,7 +843,7 @@ PROTOBUF_NOINLINE const char* TcParser::SingularVarBigint(
asm("" : "+m"(spill));
#endif
FieldType tmp;
uint64_t tmp;
PROTOBUF_ASSUME(static_cast<int8_t>(*ptr) < 0);
ptr = ParseVarint(ptr, &tmp);
@ -1000,29 +864,16 @@ template <typename FieldType>
PROTOBUF_ALWAYS_INLINE const char* TcParser::FastVarintS1(
PROTOBUF_TC_PARAM_DECL) {
using TagType = uint8_t;
// super-early success test...
if (PROTOBUF_PREDICT_TRUE(((data.data) & 0x80FF) == 0)) {
ptr += sizeof(TagType); // Consume tag
hasbits |= (uint64_t{1} << data.hasbit_idx());
uint8_t value = data.data >> 8;
RefAt<FieldType>(msg, data.offset()) = value;
ptr += 1;
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
if (PROTOBUF_PREDICT_FALSE(data.coded_tag<TagType>() != 0)) {
PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
ptr += sizeof(TagType); // Consume tag
hasbits |= (uint64_t{1} << data.hasbit_idx());
auto tmp =
ParseFallbackPair<FieldType>(ptr, static_cast<int8_t>(data.data >> 8));
ptr = tmp.first;
int64_t res;
ptr = ShiftMixParseVarint<FieldType>(ptr + sizeof(TagType), res);
if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) {
PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
RefAt<FieldType>(msg, data.offset()) = tmp.second;
hasbits |= (uint64_t{1} << data.hasbit_idx());
RefAt<FieldType>(msg, data.offset()) = res;
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}

@ -0,0 +1,188 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2023 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef GOOGLE_PROTOBUF_VARINT_SHUFFLE_H__
#define GOOGLE_PROTOBUF_VARINT_SHUFFLE_H__
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
#include "absl/log/absl_check.h" // for PROTOBUF_ASSUME
// Must be included last.
#include "google/protobuf/port_def.inc"
namespace google {
namespace protobuf {
namespace internal {
// Shifts `byte` left by n * 7 bits and fills the vacated low bits with the
// top bits of `ones` — the software mirror of one x86-64 SHLD instruction.
template <int n>
inline PROTOBUF_ALWAYS_INLINE int64_t VarintShlByte(int8_t byte, int64_t ones) {
  constexpr int kShift = n * 7;
  // Sign extension of `byte` fills the bits above the shifted payload with
  // ones whenever the continuation bit is set.
  const uint64_t hi = static_cast<uint64_t>(byte) << kShift;
  const uint64_t lo = static_cast<uint64_t>(ones) >> (64 - kShift);
  return static_cast<int64_t>(hi | lo);
}
// Shifts `byte` left by n * 7 bits (vacated bits filled from `ones`), then
// folds the shifted chunk into the accumulator `res` with a bitwise AND.
// Returns true when the accumulated result is non-negative, i.e. the byte
// just consumed had its continuation bit clear.
template <int n>
inline PROTOBUF_ALWAYS_INLINE bool VarintShlAnd(int8_t byte, int64_t ones,
                                                int64_t& res) {
  const int64_t chunk = VarintShlByte<n>(byte, ones);
  res &= chunk;
  return !(res < 0);
}
// Shifts `byte` left by n * 7 bits (vacated bits filled from `ones`) and
// stores the chunk into the output-only parameter `res`.
// Returns true when the result is non-negative, i.e. the byte just consumed
// had its continuation bit clear.
template <int n>
inline PROTOBUF_ALWAYS_INLINE bool VarintShl(int8_t byte, int64_t ones,
                                             int64_t& res) {
  const int64_t chunk = VarintShlByte<n>(byte, ones);
  res = chunk;
  return !(chunk < 0);
}
// Parses a varint of type VarintType (int32_t or int64_t) starting at `p`
// and stores the sign-extended result in `res1`. Returns a pointer just past
// the consumed bytes, or nullptr when the varint is unterminated (the 10th
// byte still has its continuation bit set).
//
// `limit` bounds how many bytes may be consumed. If the varint has not ended
// after `limit` bytes, the parse stops there via the limitN exits below,
// leaving a partial result that is always negative (asserted at limit0).
//
// For 32-bit varints, bytes 5..9 are only inspected for their continuation
// bits; callers are expected to truncate `res1` to 32 bits themselves.
template <typename VarintType, int limit = 10>
inline PROTOBUF_ALWAYS_INLINE const char* ShiftMixParseVarint(const char* p,
                                                              int64_t& res1) {
  using Signed = std::make_signed_t<VarintType>;
  constexpr bool kIs64BitVarint = std::is_same<Signed, int64_t>::value;
  constexpr bool kIs32BitVarint = std::is_same<Signed, int32_t>::value;
  static_assert(kIs64BitVarint || kIs32BitVarint, "");

  // The algorithm relies on sign extension for each byte to set all high bits
  // when the varint continues. It also relies on asserting all of the lower
  // bits for each successive byte read. This allows the result to be
  // aggregated using a bitwise AND. For example:
  //
  //          8       1          64     57 ... 24     17  16      9  8       1
  // ptr[0] = 1aaa aaaa ; res1 = 1111 1111 ... 1111 1111  1111 1111  1aaa aaaa
  // ptr[1] = 1bbb bbbb ; res2 = 1111 1111 ... 1111 1111  11bb bbbb  b111 1111
  // ptr[2] = 0ccc cccc ; res3 = 0000 0000 ... 000c cccc  cc11 1111  1111 1111
  //                             ---------------------------------------------
  //        res1 & res2 & res3 = 0000 0000 ... 000c cccc  ccbb bbbb  baaa aaaa
  //
  // On x86-64, a shld from a single register filled with enough 1s in the
  // high bits can accomplish all this in one instruction. It so happens that
  // res1 has 57 high bits of ones, which is enough for the largest shift done.
  //
  // Just as importantly, by keeping results in res1, res2, and res3, we take
  // advantage of the superscalar abilities of the CPU.

  // `next` consumes one byte (sign-extended); `last` re-reads the byte most
  // recently consumed without advancing.
  const auto next = [&p] { return static_cast<const int8_t>(*p++); };
  const auto last = [&p] { return static_cast<const int8_t>(p[-1]); };

  int64_t res2, res3;  // accumulated result chunks

  res1 = next();
  // Fast path: a single byte with the continuation bit clear.
  if (PROTOBUF_PREDICT_TRUE(res1 >= 0)) return p;
  if (limit <= 1) goto limit0;

  // Densify all ops with explicit FALSE predictions from here on, except that
  // we predict length = 5 as a common length for fields like timestamp.
  if (PROTOBUF_PREDICT_FALSE(VarintShl<1>(next(), res1, res2))) goto done1;
  if (limit <= 2) goto limit1;
  if (PROTOBUF_PREDICT_FALSE(VarintShl<2>(next(), res1, res3))) goto done2;
  if (limit <= 3) goto limit2;
  if (PROTOBUF_PREDICT_FALSE(VarintShlAnd<3>(next(), res1, res2))) goto done2;
  if (limit <= 4) goto limit2;
  if (PROTOBUF_PREDICT_TRUE(VarintShlAnd<4>(next(), res1, res3))) goto done2;
  if (limit <= 5) goto limit2;

  if (kIs64BitVarint) {
    if (PROTOBUF_PREDICT_FALSE(VarintShlAnd<5>(next(), res1, res2))) goto done2;
    if (limit <= 6) goto limit2;
    if (PROTOBUF_PREDICT_FALSE(VarintShlAnd<6>(next(), res1, res3))) goto done2;
    if (limit <= 7) goto limit2;
    if (PROTOBUF_PREDICT_FALSE(VarintShlAnd<7>(next(), res1, res2))) goto done2;
    if (limit <= 8) goto limit2;
    if (PROTOBUF_PREDICT_FALSE(VarintShlAnd<8>(next(), res1, res3))) goto done2;
    if (limit <= 9) goto limit2;
  } else {
    // An overlong int32 is expected to span the full 10 bytes; the payload
    // bits beyond byte 4 are dropped, so only the continuation bits matter.
    if (PROTOBUF_PREDICT_FALSE(!(next() & 0x80))) goto done2;
    if (limit <= 6) goto limit2;
    if (PROTOBUF_PREDICT_FALSE(!(next() & 0x80))) goto done2;
    if (limit <= 7) goto limit2;
    if (PROTOBUF_PREDICT_FALSE(!(next() & 0x80))) goto done2;
    if (limit <= 8) goto limit2;
    if (PROTOBUF_PREDICT_FALSE(!(next() & 0x80))) goto done2;
    if (limit <= 9) goto limit2;
  }

  // For valid 64bit varints, the 10th byte/ptr[9] should be exactly 1. In
  // this case, the continuation bit of ptr[8] already set the top bit of res3
  // correctly, so all we have to do is check that the expected case is true.
  if (PROTOBUF_PREDICT_TRUE(next() == 1)) goto done2;

  if (PROTOBUF_PREDICT_FALSE(last() & 0x80)) {
    // If the continue bit is set, it is an unterminated varint.
    return nullptr;
  }

  // A zero value of the first bit of the 10th byte represents an
  // over-serialized varint. This case should not happen, but if it does (say,
  // due to a nonconforming serializer), deassert the continuation bit that
  // came from ptr[8].
  if (kIs64BitVarint && (last() & 1) == 0) {
    static constexpr int bits = 64 - 1;
#if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__)
    // Use a small instruction since this is an uncommon code path.
    asm("btc %[bits], %[res3]" : [res3] "+r"(res3) : [bits] "i"(bits));
#else
    res3 ^= int64_t{1} << bits;
#endif
  }

  // Terminated varint: fold the accumulated chunks together.
done2:
  res2 &= res3;
done1:
  res1 &= res2;
  PROTOBUF_ASSUME(p != nullptr);
  return p;

  // Limit hit mid-varint: fold what we have; the partial result is negative
  // because the last consumed byte still had its continuation bit set.
limit2:
  res2 &= res3;
limit1:
  res1 &= res2;
limit0:
  PROTOBUF_ASSUME(p != nullptr);
  PROTOBUF_ASSUME(res1 < 0);
  return p;
}
} // namespace internal
} // namespace protobuf
} // namespace google
#include "google/protobuf/port_undef.inc"
#endif // GOOGLE_PROTOBUF_VARINT_SHUFFLE_H__

@ -0,0 +1,331 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2023 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "google/protobuf/varint_shuffle.h"
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
// Must be included last.
#include "google/protobuf/port_def.inc"
using testing::Eq;
using testing::IsNull;
using testing::NotNull;
using testing::Range;
using testing::TestWithParam;
namespace google {
namespace protobuf {
namespace internal {
namespace {
// Reinterprets a raw byte as its unsigned value, widened to 64 bits.
uint64_t ToInt64(char c) {
  return static_cast<uint64_t>(static_cast<uint8_t>(c));
}

// Reinterprets a raw byte as its unsigned value, widened to 32 bits.
uint32_t ToInt32(char c) {
  return static_cast<uint32_t>(static_cast<uint8_t>(c));
}

// Left-shifts in the unsigned domain (well-defined on overflow) and hands
// back the resulting bit pattern as a signed value.
int32_t Shl(uint32_t v, int bits) {
  const uint32_t shifted = v << bits;
  return static_cast<int32_t>(shifted);
}
int32_t Shl(int32_t v, int bits) {
  return Shl(static_cast<uint32_t>(v), bits);
}
int64_t Shl(uint64_t v, int bits) {
  const uint64_t shifted = v << bits;
  return static_cast<int64_t>(shifted);
}
int64_t Shl(int64_t v, int bits) {
  return Shl(static_cast<uint64_t>(v), bits);
}
// A naive, easy to verify 32-bit parse used as a reference implementation.
// Returns the number of bytes consumed (11 when the varint is unterminated
// after 10 bytes, in which case `res` is left untouched). Payload bits
// beyond the first five bytes are dropped, matching 32-bit truncation.
int NaiveParse(const char* p, int32_t& res) {
  int len = 0;
  // Start with the raw first byte; each step below adds the next byte's
  // contribution and cancels the continuation bit counted in the previous
  // byte (hence the `- 1`). All arithmetic is unsigned, so wrap is defined.
  uint32_t acc = static_cast<uint8_t>(*p);
  while (*p++ & 0x80) {
    if (++len == 10) return 11;
    if (len < 5) {
      const uint32_t byte = static_cast<uint8_t>(*p);
      acc += (byte - 1) << (len * 7);
    }
  }
  res = static_cast<int32_t>(acc);
  return ++len;
}
// A naive, easy to verify 64-bit parse used as a reference implementation.
// Returns the number of bytes consumed (11 when the varint is unterminated
// after 10 bytes, in which case `res` is left untouched).
int NaiveParse(const char* p, int64_t& res) {
  int len = 0;
  // Start with the raw first byte; each step adds the next byte's
  // contribution and cancels the continuation bit counted in the previous
  // byte (hence the `- 1`). All arithmetic is unsigned, so wrap is defined.
  uint64_t acc = static_cast<uint8_t>(*p);
  while (*p++ & 0x80) {
    if (++len == 10) return 11;
    const uint64_t byte = static_cast<uint8_t>(*p);
    acc += (byte - 1) << (len * 7);
  }
  res = static_cast<int64_t>(acc);
  return ++len;
}
// A naive, easy to verify varint serializer used as a reference
// implementation. Writes `value` into `p` (which must have room for up to 10
// bytes) and returns the number of bytes written.
int NaiveSerialize(char* p, uint64_t value) {
  int n = 0;
  for (; value > 127; value >>= 7) {
    // Low 7 payload bits plus the continuation bit.
    p[n++] = static_cast<char>(0x80 | (value & 0x7F));
  }
  p[n++] = static_cast<char>(value);  // final byte, continuation bit clear
  return n;
}
// Parameterized fixture for 32-bit varint tests. The parameter is the varint
// length (or parse limit) under test, in the range [1, 10].
class ShiftMixParseVarint32Test : public TestWithParam<int> {
 public:
  int length() const { return GetParam(); }
};

// Parameterized fixture for 64-bit varint tests; same parameter meaning.
class ShiftMixParseVarint64Test : public TestWithParam<int> {
 public:
  int length() const { return GetParam(); }
};

// Instantiate both suites for lengths/limits 1 through 10 inclusive.
INSTANTIATE_TEST_SUITE_P(Default, ShiftMixParseVarint32Test, Range(1, 11));
INSTANTIATE_TEST_SUITE_P(Default, ShiftMixParseVarint64Test, Range(1, 11));
// Parses a 32-bit varint via ShiftMixParseVarint. The parser always produces
// its result as an int64; narrow it to int32 here by plain truncation.
template <int limit = 10>
const char* Parse(const char* data, int32_t& res) {
  int64_t res64;
  const char* ret = ShiftMixParseVarint<int32_t, limit>(data, res64);
  res = res64;
  return ret;
}

// Parses a 64-bit varint via ShiftMixParseVarint.
template <int limit = 10>
const char* Parse(const char* data, int64_t& res) {
  return ShiftMixParseVarint<int64_t, limit>(data, res);
}
// Maps the runtime `rtlimit` onto the compile-time `limit` template parameter
// of Parse<> by walking up one instantiation at a time until they match.
template <int limit = 0>
const char* ParseWithLimit(int rtlimit, const char* data, int32_t& res) {
  if (rtlimit > limit) return ParseWithLimit<limit + 1>(rtlimit, data, res);
  return Parse<limit>(data, res);
}

template <int limit = 0>
const char* ParseWithLimit(int rtlimit, const char* data, int64_t& res) {
  if (rtlimit > limit) return ParseWithLimit<limit + 1>(rtlimit, data, res);
  return Parse<limit>(data, res);
}

// Recursion terminators: 10 is the maximum byte count of any varint.
template <>
const char* ParseWithLimit<10>(int rtlimit, const char* data, int32_t& res) {
  return Parse<10>(data, res);
}

template <>
const char* ParseWithLimit<10>(int rtlimit, const char* data, int64_t& res) {
  return Parse<10>(data, res);
}
// Builds a varint of exactly `len` bytes — `len - 1` distinct continuation
// bytes followed by a 0x01 terminator — then verifies ShiftMixParseVarint
// consumes `len` bytes and agrees with the naive reference parse.
template <typename T>
void TestAllLengths(int len) {
  std::vector<char> bytes;
  for (int i = 1; i < len; ++i) {
    // 0xC1 + 2 * i: a payload unique to its position, continuation bit set.
    bytes.push_back(static_cast<char>(0xC1 + (i << 1)));
  }
  bytes.push_back('\x01');  // terminating byte (continuation bit clear)
  const char* data = bytes.data();
  T expected;
  ASSERT_THAT(NaiveParse(data, expected), Eq(len));
  T result;
  const char* p = Parse(data, result);
  ASSERT_THAT(p, NotNull());
  ASSERT_THAT(p - data, Eq(len));
  ASSERT_THAT(result, Eq(expected));
}

TEST_P(ShiftMixParseVarint32Test, AllLengths) {
  TestAllLengths<int32_t>(length());
}

TEST_P(ShiftMixParseVarint64Test, AllLengths) {
  TestAllLengths<int64_t>(length());
}
// Parses a byte sequence whose last byte is 0x7E (terminated, but with the
// low bit clear — non-canonical when it is the 10th byte) and checks
// agreement with the naive reference parse. For len < 10 the sequence is cut
// short with a 0x00 terminator at position `len`.
template <typename T>
void TestNonCanonicalValue(int len) {
  char data[] = {'\xc3', '\xc5', '\xc7', '\xc9', '\xcb',
                 '\xcd', '\xcf', '\xd1', '\xd3', '\x7E'};
  if (len < 10) data[len++] = 0;  // terminate early at the requested length
  T expected;
  ASSERT_THAT(NaiveParse(data, expected), Eq(len));
  T result;
  const char* p = Parse(data, result);
  ASSERT_THAT(p, NotNull());
  ASSERT_THAT(p - data, Eq(len));
  ASSERT_THAT(result, Eq(expected));
}

TEST_P(ShiftMixParseVarint32Test, NonCanonicalValue) {
  TestNonCanonicalValue<int32_t>(length());
}

TEST_P(ShiftMixParseVarint64Test, NonCanonicalValue) {
  TestNonCanonicalValue<int64_t>(length());
}
// Parses an over-long (non-canonical) encoding of zero — up to nine 0x80
// continuation bytes with empty payloads. The parsed value must be 0
// regardless of the encoded length.
template <typename T>
void TestNonCanonicalZero(int len) {
  char data[] = {'\x80', '\x80', '\x80', '\x80', '\x80',
                 '\x80', '\x80', '\x80', '\x80', '\x7E'};
  if (len < 10) data[len++] = 0;  // terminate early at the requested length
  T expected;
  ASSERT_THAT(NaiveParse(data, expected), Eq(len));
  ASSERT_THAT(expected, Eq(0));
  T result;
  const char* p = Parse(data, result);
  ASSERT_THAT(p, NotNull());
  ASSERT_THAT(p - data, Eq(len));
  ASSERT_THAT(result, Eq(expected));
}

TEST_P(ShiftMixParseVarint32Test, NonCanonicalZero) {
  TestNonCanonicalZero<int32_t>(length());
}

TEST_P(ShiftMixParseVarint64Test, NonCanonicalZero) {
  TestNonCanonicalZero<int64_t>(length());
}
// Serializes a negative value (which always occupies the full 10 varint
// bytes) and parses it back under each compile-time byte limit, checking
// that the parser stops exactly at the limit.
TEST_P(ShiftMixParseVarint32Test, HittingLimit) {
  const int limit = length();
  int32_t res = 0x94939291L;  // negative, so it serializes into 10 bytes
  char data[10];
  int serialized_len = NaiveSerialize(data, res);
  ASSERT_THAT(serialized_len, Eq(10));
  int32_t result;
  const char* p = ParseWithLimit(limit, data, result);
  ASSERT_THAT(p, testing::NotNull());
  ASSERT_THAT(p - data, Eq(limit));
  if (limit < 5) {
    // A parse cut off inside the 32-bit payload leaves the not-yet-parsed
    // high bits filled with ones.
    res |= Shl(int32_t{-1}, limit * 7);
  }
  ASSERT_THAT(result, Eq(res));
}

// 64-bit counterpart of the test above.
TEST_P(ShiftMixParseVarint64Test, HittingLimit) {
  const int limit = length();
  int64_t res = 0x9897969594939291LL;  // negative, serializes into 10 bytes
  char data[10];
  int serialized_len = NaiveSerialize(data, res);
  ASSERT_THAT(serialized_len, Eq(10));
  int64_t result;
  const char* p = ParseWithLimit(limit, data, result);
  ASSERT_THAT(p, testing::NotNull());
  ASSERT_THAT(p - data, Eq(limit));
  if (limit != 10) {
    // Any limit below 10 truncates: the unparsed high bits come back as ones.
    res |= Shl(int64_t{-1}, limit * 7);
  }
  ASSERT_THAT(result, Eq(res));
}
// Serializes a value that needs exactly `limit` bytes and verifies that the
// limited parse consumes all of them and reproduces the value exactly.
TEST_P(ShiftMixParseVarint32Test, AtOrBelowLimit) {
  const int limit = length();
  if (limit > 5) GTEST_SKIP() << "N/A";
  // Shift so that the top payload bit lands in the last allowed byte.
  int32_t res = 0x94939291ULL >> (35 - 7 * limit);
  char data[10];
  int serialized_len = NaiveSerialize(data, res);
  // limit == 5 makes the value negative, which serializes into 10 bytes.
  ASSERT_THAT(serialized_len, Eq(limit == 5 ? 10 : limit));
  int32_t result;
  const char* p = ParseWithLimit(limit, data, result);
  ASSERT_THAT(p, testing::NotNull());
  ASSERT_THAT(p - data, Eq(limit));
  ASSERT_THAT(result, Eq(res));
}

TEST_P(ShiftMixParseVarint64Test, AtOrBelowLimit) {
  const int limit = length();
  // Shift so that the top payload bit lands in the last allowed byte.
  int64_t res = 0x9897969594939291ULL >> (70 - 7 * limit);
  char data[10];
  int serialized_len = NaiveSerialize(data, res);
  ASSERT_THAT(serialized_len, Eq(limit));
  int64_t result;
  const char* p = ParseWithLimit(limit, data, result);
  ASSERT_THAT(p, testing::NotNull());
  ASSERT_THAT(p - data, Eq(limit));
  ASSERT_THAT(result, Eq(res));
}
// A varint whose 10th byte still has the continuation bit set is
// unterminated; the parser must signal failure by returning nullptr.
TEST(ShiftMixParseVarint32Test, OverLong) {
  char data[] = {'\xc3', '\xc5', '\xc7', '\xc9', '\xcb',
                 '\xcd', '\xcf', '\xd1', '\xd3', '\x81'};
  int32_t result;
  const char* p = Parse(data, result);
  ASSERT_THAT(p, IsNull());
}

TEST(ShiftMixParseVarint64Test, OverLong) {
  char data[] = {'\xc3', '\xc5', '\xc7', '\xc9', '\xcb',
                 '\xcd', '\xcf', '\xd1', '\xd3', '\x81'};
  int64_t result;
  const char* p = Parse(data, result);
  ASSERT_THAT(p, IsNull());
}
// A final byte of 0x7F terminates the varint but carries payload bits beyond
// the value's range; the parser must drop those overlong bits the same way
// the naive reference implementation does.
TEST(ShiftMixParseVarint32Test, DroppingOverlongBits) {
  char data[] = {'\xc3', '\xc5', '\xc7', '\xc9', '\x7F'};
  int32_t expected;
  ASSERT_THAT(NaiveParse(data, expected), Eq(5));
  int32_t result;
  const char* p = Parse(data, result);
  ASSERT_THAT(p, NotNull());
  ASSERT_THAT(p - data, Eq(5));
  ASSERT_THAT(result, Eq(expected));
}

TEST(ShiftMixParseVarint64Test, DroppingOverlongBits) {
  char data[] = {'\xc3', '\xc5', '\xc7', '\xc9', '\xcb',
                 '\xcd', '\xcf', '\xd1', '\xd3', '\x7F'};
  int64_t expected;
  ASSERT_THAT(NaiveParse(data, expected), Eq(10));
  int64_t result;
  const char* p = Parse(data, result);
  ASSERT_THAT(p, NotNull());
  ASSERT_THAT(p - data, Eq(10));
  ASSERT_THAT(result, Eq(expected));
}
} // namespace
} // namespace internal
} // namespace protobuf
} // namespace google
Loading…
Cancel
Save