Add a consistency check to the parser to verify that has bits are valid on parse failure, and fix the issues found by tests.
Keeping the invariant can help performance and future changes.

PiperOrigin-RevId: 689799465
pull/18904/head
Protobuf Team Bot 4 months ago committed by Copybara-Service
parent 5aff431888
commit b5fca3e1b5
  1. 3
      src/google/protobuf/BUILD.bazel
  2. 8
      src/google/protobuf/generated_message_tctable_impl.h
  3. 175
      src/google/protobuf/generated_message_tctable_lite.cc
  4. 39
      src/google/protobuf/generated_message_tctable_lite_test.cc

@ -1583,9 +1583,12 @@ cc_test(
}),
deps = [
":cc_test_protos",
":descriptor_visitor",
":port",
":protobuf",
":protobuf_lite",
"//src/google/protobuf/io",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings",

@ -822,6 +822,9 @@ class PROTOBUF_EXPORT TcParser final {
};
}
// Debug helper: checks that the has bits of `msg` agree with the values
// stored in its fields, using the field entries of `table`. A violation
// triggers an ABSL_CHECK failure that names the message type and field number.
static void VerifyHasBitConsistency(const MessageLite* msg,
const TcParseTableBase* table);
private:
// Optimized small tag varint parser for int32/int64
template <typename FieldType>
@ -1027,6 +1030,8 @@ class PROTOBUF_EXPORT TcParser final {
static absl::string_view MessageName(const TcParseTableBase* table);
static absl::string_view FieldName(const TcParseTableBase* table,
const TcParseTableBase::FieldEntry*);
// Returns the field number of the given field entry of `table`. The lookup is
// a linear search and is only used to build debug check messages.
static int FieldNumber(const TcParseTableBase* table,
const TcParseTableBase::FieldEntry*);
static bool ChangeOneof(const TcParseTableBase* table,
const TcParseTableBase::FieldEntry& entry,
uint32_t field_num, ParseContext* ctx,
@ -1152,6 +1157,9 @@ inline PROTOBUF_ALWAYS_INLINE const char* TcParser::ParseLoop(
if (ABSL_PREDICT_FALSE(table->has_post_loop_handler)) {
return table->post_loop_handler(msg, ptr, ctx);
}
if (ABSL_PREDICT_FALSE(PerformDebugChecks() && ptr == nullptr)) {
VerifyHasBitConsistency(msg, table);
}
return ptr;
}

@ -20,7 +20,9 @@
#include "absl/log/absl_log.h"
#include "absl/numeric/bits.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "google/protobuf/arenastring.h"
#include "google/protobuf/generated_enum_util.h"
#include "google/protobuf/generated_message_tctable_decl.h"
@ -73,6 +75,120 @@ const char* TcParser::GenericFallbackLite(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_TC_PARAM_PASS);
}
namespace {
// Returns true if the has bit selected by `entry.has_idx` is set in `msg`.
bool ReadHas(const FieldEntry& entry, const MessageLite* msg) {
  const uint32_t bit_index = static_cast<uint32_t>(entry.has_idx);
  // The bit is read from the 32-bit word that contains it: four bytes per
  // 32 bits of index.
  const uint32_t word_offset = bit_index / 32 * 4;
  const uint32_t bit_mask = uint32_t{1} << (bit_index % 32);
  const auto& word = TcParser::RefAt<const uint32_t>(msg, word_offset);
  return (word & bit_mask) != 0;
}
} // namespace
// Debug check that the has bits of `msg` agree with the stored field values.
//
// Every entry of `table->field_entries()` with `kFcOptional` cardinality is
// inspected, and an ABSL_CHECK failure (reporting message type and field
// number) is raised when:
//  - a varint/fixed field has its has bit off but differs from the value in
//    the default instance;
//  - an ArenaStringPtr string has its has bit on but still is the default;
//  - a Cord or inlined string has its has bit off but differs from the value
//    in the default instance;
//  - a message/group field has its has bit on but holds a null pointer.
//
// Entries with any other cardinality are skipped. Nothing is checked when
// `table->has_bits_offset` is zero.
void TcParser::VerifyHasBitConsistency(const MessageLite* msg,
                                       const TcParseTableBase* table) {
  namespace fl = internal::field_layout;
  if (table->has_bits_offset == 0) {
    // Nothing to check
    return;
  }

  for (const auto& entry : table->field_entries()) {
    // Built lazily: only evaluated when one of the checks below fails.
    const auto print_error = [&] {
      return absl::StrFormat("Type=%s Field=%d\n", msg->GetTypeName(),
                             FieldNumber(table, &entry));
    };

    // Skip only this entry. Returning here would end the verification at the
    // first non-optional field and leave every later field unchecked.
    if ((entry.type_card & fl::kFcMask) != fl::kFcOptional) continue;

    const bool has_bit = ReadHas(entry, msg);
    const void* base = msg;
    const void* default_base = table->default_instance();
    if ((entry.type_card & fl::kSplitMask) == fl::kSplitTrue) {
      // Split fields are reached through a pointer stored at the split
      // offset, both in the message and in the default instance.
      const size_t offset = table->field_aux(kSplitOffsetAuxIdx)->offset;
      base = TcParser::RefAt<const void*>(base, offset);
      default_base = TcParser::RefAt<const void*>(default_base, offset);
    }

    switch (entry.type_card & fl::kFkMask) {
      case fl::kFkVarint:
      case fl::kFkFixed:
        // Numerics can have any value when the has bit is on, so there is
        // nothing to verify for this entry; move on to the next one.
        if (has_bit) continue;
        switch (entry.type_card & fl::kRepMask) {
          case fl::kRep8Bits:
            ABSL_CHECK_EQ(RefAt<bool>(base, entry.offset),
                          RefAt<bool>(default_base, entry.offset))
                << print_error();
            break;
          case fl::kRep32Bits:
            ABSL_CHECK_EQ(RefAt<uint32_t>(base, entry.offset),
                          RefAt<uint32_t>(default_base, entry.offset))
                << print_error();
            break;
          case fl::kRep64Bits:
            ABSL_CHECK_EQ(RefAt<uint64_t>(base, entry.offset),
                          RefAt<uint64_t>(default_base, entry.offset))
                << print_error();
            break;
        }
        break;

      case fl::kFkString:
        switch (entry.type_card & fl::kRepMask) {
          case fl::kRepAString:
            if (has_bit) {
              // Must not point to the default.
              ABSL_CHECK(!RefAt<ArenaStringPtr>(base, entry.offset).IsDefault())
                  << print_error();
            } else {
              // We should technically check that the value matches the default
              // value of the field, but the prototype does not actually contain
              // this value. Non-empty defaults are loaded on access.
            }
            break;
          case fl::kRepCord:
            if (!has_bit) {
              // If the has bit is off, it must match the default.
              ABSL_CHECK_EQ(RefAt<absl::Cord>(base, entry.offset),
                            RefAt<absl::Cord>(default_base, entry.offset))
                  << print_error();
            }
            break;
          case fl::kRepIString:
            if (!has_bit) {
              // If the has bit is off, it must match the default.
              ABSL_CHECK_EQ(
                  RefAt<InlinedStringField>(base, entry.offset).Get(),
                  RefAt<InlinedStringField>(default_base, entry.offset).Get())
                  << print_error();
            }
            break;
          case fl::kRepSString:
            Unreachable();
        }
        break;

      case fl::kFkMessage:
        switch (entry.type_card & fl::kRepMask) {
          case fl::kRepMessage:
          case fl::kRepGroup:
            if (has_bit) {
              ABSL_CHECK(RefAt<const MessageLite*>(base, entry.offset) !=
                         nullptr)
                  << print_error();
            } else {
              // An off has_bit does not imply a null pointer.
              // We might have a previous instance that we cached.
            }
            break;
          default:
            Unreachable();
        }
        break;

      default:
        // All other types are not `optional`.
        Unreachable();
    }
  }
}
//////////////////////////////////////////////////////////////////////////////
// Core fast parsing implementation:
//////////////////////////////////////////////////////////////////////////////
@ -226,6 +342,46 @@ absl::string_view TcParser::FieldName(const TcParseTableBase* table,
field_index + 1);
}
// Returns the field number that corresponds to `entry`, which must point into
// the field entries of `table`.
//
// The table is laid out for number -> entry lookups, so the reverse direction
// walks the skip maps in field number order and counts present fields until
// it has passed as many as precede `entry`. The linear cost is acceptable
// because this is only used for debug check messages.
int TcParser::FieldNumber(const TcParseTableBase* table,
                          const TcParseTableBase::FieldEntry* entry) {
  // How many present fields still come before `entry`.
  size_t remaining = entry - table->field_entries_begin();

  // Walks the set bits of `present` from lowest to highest; bit N stands for
  // field number `first + N`. Yields the field number once `remaining`
  // reaches zero, or nullopt when the bitmap runs out first.
  const auto scan = [&remaining](uint32_t present,
                                 int first) -> absl::optional<int> {
    while (present != 0) {
      if (remaining == 0) {
        return absl::countr_zero(present) + first;
      }
      --remaining;
      present &= present - 1;  // Clear the lowest set bit.
    }
    return absl::nullopt;
  };

  // Field numbers 1..32 are covered by `skipmap32`; it is inverted so that
  // set bits mark the fields that have an entry.
  const absl::optional<int> low_number = scan(~table->skipmap32, 1);
  if (low_number.has_value()) {
    return *low_number;
  }

  // The lookup table is a sequence of ranges terminated by 0xFFFF 0xFFFF.
  // Each range is: first field number (two 16-bit halves), a count of skip
  // entries, then per group of 16 fields a 16-bit skip bitmap followed by a
  // 16-bit field-entry offset.
  const uint16_t* cursor = table->field_lookup_begin();
  while (cursor[0] != 0xFFFF || cursor[1] != 0xFFFF) {
    const uint32_t range_start = cursor[0] | (cursor[1] << 16);
    const uint16_t group_count = cursor[2];
    cursor += 3;
    for (uint16_t group = 0; group < group_count; ++group, cursor += 2) {
      const absl::optional<int> number =
          scan(static_cast<uint16_t>(~cursor[0]), range_start + 16 * group);
      if (number.has_value()) {
        return *number;
      }
    }
  }
  Unreachable();
}
PROTOBUF_NOINLINE const char* TcParser::Error(PROTOBUF_TC_PARAM_NO_DATA_DECL) {
(void)ctx;
(void)ptr;
@ -1403,6 +1559,19 @@ PROTOBUF_ALWAYS_INLINE inline bool IsValidUTF8(ArenaStringPtr& field) {
}
// Repairs `field` after a failed string parse.
//
// The failure might have left the string in its IsDefault state while the has
// bit was already set, which breaks the message invariants. Consistency is
// restored by guaranteeing that the string always exists: a default string is
// replaced with an empty one, using the arena of `msg`.
void EnsureArenaStringIsNotDefault(const MessageLite* msg,
                                   ArenaStringPtr* field) {
  if (!field->IsDefault()) return;
  field->Set("", msg->GetArena());
}
// Catch-all overload for field pointers that are not `ArenaStringPtr*`: there
// is no default-string state to repair, so it intentionally does nothing.
PROTOBUF_UNUSED void EnsureArenaStringIsNotDefault(const MessageLite* msg,
void*) {}
} // namespace
template <typename TagType, typename FieldType, TcParser::Utf8Type utf8>
@ -1423,6 +1592,7 @@ inline PROTOBUF_ALWAYS_INLINE const char* TcParser::SingularString(
ptr = ReadStringNoArena(msg, ptr, ctx, data.aux_idx(), table, field);
}
if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) {
EnsureArenaStringIsNotDefault(msg, &field);
PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
switch (utf8) {
@ -2180,7 +2350,10 @@ PROTOBUF_NOINLINE const char* TcParser::MpString(PROTOBUF_TC_PARAM_DECL) {
std::string* str = field.MutableNoCopy(nullptr);
ptr = InlineGreedyStringParser(str, ptr, ctx);
}
if (!ptr) break;
if (ABSL_PREDICT_FALSE(ptr == nullptr)) {
EnsureArenaStringIsNotDefault(msg, &field);
break;
}
is_valid = MpVerifyUtf8(field.Get(), table, entry, xform_val);
break;
}

@ -11,14 +11,19 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/algorithm/container.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/descriptor_visitor.h"
#include "google/protobuf/generated_message_tctable_decl.h"
#include "google/protobuf/generated_message_tctable_impl.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/parse_context.h"
#include "google/protobuf/unittest.pb.h"
#include "google/protobuf/wire_format_lite.h"
@ -333,6 +338,10 @@ class FindFieldEntryTest : public ::testing::Test {
return TcParser::FieldName(&table.header, entry);
}
// Forwards to `TcParser::FieldNumber` for the field entry at position `index`
// of `table`.
static int FieldNumber(const TcParseTableBase* table, size_t index) {
  const auto* entry = table->field_entries_begin() + index;
  return TcParser::FieldNumber(table, entry);
}
// Calls the private `MessageName` function.
template <size_t kFastTableSizeLog2, size_t kNumEntries, size_t kNumFieldAux,
size_t kNameTableSize, size_t kFieldLookupTableSize>
@ -346,6 +355,36 @@ class FindFieldEntryTest : public ::testing::Test {
static constexpr int small_scan_size() { return TcParser::kMtSmallScanSize; }
};
// Verifies, for every message type registered in the binary, that the field
// number computed for the i-th field entry equals the i-th smallest field
// number declared in the descriptor.
TEST_F(FindFieldEntryTest, FieldNumberWorksForAllFields) {
  std::vector<std::string> generated_files;
  DescriptorPool::internal_generated_database()->FindAllFileNames(
      &generated_files);

  for (const auto& generated_file : generated_files) {
    SCOPED_TRACE(generated_file);
    const auto* file_desc =
        DescriptorPool::generated_pool()->FindFileByName(generated_file);

    VisitDescriptors(*file_desc, [&](const Descriptor& message) {
      SCOPED_TRACE(message.full_name());

      const auto* default_msg =
          MessageFactory::generated_factory()->GetPrototype(&message);
      const auto* parse_table = internal::GetClassData(*default_msg)->tc_table;

      // Expected numbers, in ascending order.
      std::vector<int> expected_numbers;
      for (auto* field_desc : internal::FieldRange(&message)) {
        expected_numbers.push_back(field_desc->number());
      }
      absl::c_sort(expected_numbers);

      for (int index = 0; index < message.field_count(); ++index) {
        EXPECT_EQ(FieldNumber(parse_table, index), expected_numbers[index]);
      }
    });
  }
}
TEST_F(FindFieldEntryTest, SequentialFieldRange) {
// Look up fields that are within the range of `lookup_table_offset`.
// clang-format off

Loading…
Cancel
Save