Prepare UnknownFieldSet for replacing `std::string` with `absl::string_view`:

- Add overloads that take `absl::Cord` and `std::string&&` as inputs, putting the burden on the implementation instead of on users.
- Add an overload that returns `absl::Span<char>` for callers with stricter performance requirements, where copies can be avoided altogether.
- Hide the APIs that return `std::string*` when the breaking change is enabled (via `-D PROTOBUF_TEMPORARY_ENABLE_STRING_VIEW_RETURN_TYPE`).

PiperOrigin-RevId: 655600399
pull/17568/head
Protobuf Team Bot 4 months ago committed by Copybara-Service
parent 165d2c76ed
commit 75b66c2cde
  1. 2
      src/google/protobuf/arena_unittest.cc
  2. 7
      src/google/protobuf/descriptor.cc
  3. 18
      src/google/protobuf/parse_context.cc
  4. 10
      src/google/protobuf/parse_context.h
  5. 40
      src/google/protobuf/unknown_field_set.cc
  6. 48
      src/google/protobuf/unknown_field_set.h
  7. 88
      src/google/protobuf/unknown_field_set_unittest.cc
  8. 23
      src/google/protobuf/wire_format.cc

@ -652,7 +652,7 @@ TEST(ArenaTest, UnknownFields) {
arena_message_3->mutable_unknown_fields()->AddVarint(1000, 42);
arena_message_3->mutable_unknown_fields()->AddFixed32(1001, 42);
arena_message_3->mutable_unknown_fields()->AddFixed64(1002, 42);
arena_message_3->mutable_unknown_fields()->AddLengthDelimited(1003);
arena_message_3->mutable_unknown_fields()->AddLengthDelimited(1003, "");
arena_message_3->mutable_unknown_fields()->DeleteSubrange(0, 2);
arena_message_3->mutable_unknown_fields()->DeleteByNumber(1002);
arena_message_3->mutable_unknown_fields()->DeleteByNumber(1003);

@ -8921,11 +8921,12 @@ bool DescriptorBuilder::OptionInterpreter::InterpretSingleOption(
new UnknownFieldSet());
switch ((*iter)->type()) {
case FieldDescriptor::TYPE_MESSAGE: {
std::string* outstr =
parent_unknown_fields->AddLengthDelimited((*iter)->number());
ABSL_CHECK(unknown_fields->SerializeToString(outstr))
std::string outstr;
ABSL_CHECK(unknown_fields->SerializeToString(&outstr))
<< "Unexpected failure while serializing option submessage "
<< debug_msg_name << "\".";
parent_unknown_fields->AddLengthDelimited((*iter)->number(),
std::move(outstr));
break;
}

@ -12,6 +12,7 @@
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/repeated_field.h"
#include "google/protobuf/wire_format_lite.h"
@ -265,6 +266,23 @@ const char* EpsCopyInputStream::ReadCordFallback(const char* ptr, int size,
return ptr;
}
// Out-of-line path used by ReadChars(): fills `out` through AppendSize(),
// one piece at a time, and returns the pointer AppendSize() produced
// (nullptr when the read failed).
const char* EpsCopyInputStream::ReadCharsFallback(const char* ptr,
                                                  absl::Span<char> out) {
  char* const out_end = out.data() + out.size();
  char* cursor = out.data();
  // Copies one piece into the destination and advances the write cursor.
  const auto copy_piece = [&cursor](const char* piece, int piece_size) {
    memcpy(cursor, piece, piece_size);
    cursor += piece_size;
  };
  ptr = AppendSize(ptr, out.size(), copy_piece);
  if (ABSL_PREDICT_FALSE(ptr == nullptr)) {
    // The read failed part-way: overwrite the bytes that were never written
    // with a fixed pattern so no uninitialized data is left in the object.
    memset(cursor, 0xCD, out_end - cursor);
  }
  return ptr;
}
const char* EpsCopyInputStream::InitFrom(io::ZeroCopyInputStream* zcis) {
zcis_ = zcis;

@ -212,6 +212,15 @@ class PROTOBUF_EXPORT EpsCopyInputStream {
return ReadCordFallback(ptr, size, cord);
}
// Copies exactly `out.size()` bytes starting at `ptr` into `out` and returns
// the pointer just past the bytes consumed. Requests that do not fit in what
// remains of the current buffer (including its slop bytes) are delegated to
// ReadCharsFallback().
PROTOBUF_NODISCARD const char* ReadChars(const char* ptr,
                                         absl::Span<char> out) {
  const size_t available =
      static_cast<size_t>(buffer_end_ + kSlopBytes - ptr);
  if (out.size() > available) {
    return ReadCharsFallback(ptr, out);
  }
  memcpy(out.data(), ptr, out.size());
  return ptr + out.size();
}
template <typename Tag, typename T>
PROTOBUF_NODISCARD const char* ReadRepeatedFixed(const char* ptr,
@ -369,6 +378,7 @@ class PROTOBUF_EXPORT EpsCopyInputStream {
const char* AppendStringFallback(const char* ptr, int size, std::string* str);
const char* ReadStringFallback(const char* ptr, int size, std::string* str);
const char* ReadCordFallback(const char* ptr, int size, absl::Cord* cord);
const char* ReadCharsFallback(const char* ptr, absl::Span<char> out);
static bool ParseEndsInSlopRegion(const char* begin, int overrun, int depth);
bool StreamNext(const void** data) {
bool res = zcis_->Next(data, &size_);

@ -11,10 +11,15 @@
#include "google/protobuf/unknown_field_set.h"
#include <cstring>
#include <string>
#include <utility>
#include "absl/log/absl_check.h"
#include "absl/strings/cord.h"
#include "absl/strings/internal/resize_uninitialized.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "google/protobuf/extension_set.h"
#include "google/protobuf/generated_message_tctable_impl.h"
#include "google/protobuf/io/coded_stream.h"
@ -117,6 +122,36 @@ void UnknownFieldSet::AddFixed64(int number, uint64_t value) {
field.data_.fixed64_ = value;
}
// Adds a length-delimited field numbered `number` whose payload is a
// contiguous copy of the bytes held by `value`.
void UnknownFieldSet::AddLengthDelimited(int number, const absl::Cord& value) {
  // Reserve the whole payload up front, then copy the cord chunk by chunk.
  char* dest = AddLengthDelimitedUninitialized(number, value.size()).data();
  for (absl::string_view chunk : value.Chunks()) {
    memcpy(dest, chunk.data(), chunk.size());
    dest += chunk.size();
  }
}
// Adds a length-delimited field numbered `number` backed by a string of
// `size` chars and returns a writable view over that string. The string is
// resized without initializing its contents, so the caller is expected to
// fill the returned buffer.
absl::Span<char> UnknownFieldSet::AddLengthDelimitedUninitialized(int number,
                                                                  size_t size) {
  auto& added = *fields_.Add();
  added.number_ = number;
  added.SetType(UnknownField::TYPE_LENGTH_DELIMITED);

  // The payload string is created through the set's arena() and owned by the
  // new field.
  std::string* const payload = Arena::Create<std::string>(arena());
  added.data_.string_value = payload;

  absl::strings_internal::STLStringResizeUninitialized(payload, size);
  return absl::Span<char>(*payload);
}
template <int&...>
void UnknownFieldSet::AddLengthDelimited(int number, std::string&& value) {
auto& field = *fields_.Add();
field.number_ = number;
field.SetType(UnknownField::TYPE_LENGTH_DELIMITED);
field.data_.string_value =
Arena::Create<std::string>(arena(), std::move(value));
}
template void UnknownFieldSet::AddLengthDelimited(int, std::string&&);
#if !defined(PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE)
std::string* UnknownFieldSet::AddLengthDelimited(int number) {
auto& field = *fields_.Add();
field.number_ = number;
@ -124,6 +159,7 @@ std::string* UnknownFieldSet::AddLengthDelimited(int number) {
field.data_.string_value = Arena::Create<std::string>(arena());
return field.data_.string_value;
}
#endif // PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE
UnknownFieldSet* UnknownFieldSet::AddGroup(int number) {
auto& field = *fields_.Add();
@ -284,10 +320,10 @@ class UnknownFieldParserHelper {
}
const char* ParseLengthDelimited(uint32_t num, const char* ptr,
ParseContext* ctx) {
std::string* s = unknown_->AddLengthDelimited(num);
int size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
return ctx->ReadString(ptr, size, s);
return ctx->ReadChars(ptr,
unknown_->AddLengthDelimitedUninitialized(num, size));
}
const char* ParseGroup(uint32_t num, const char* ptr, ParseContext* ctx) {
return ctx->ParseGroupInlined(ptr, num * 8 + 3, [&](const char* ptr) {

@ -24,6 +24,7 @@
#include "absl/log/absl_check.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
@ -47,6 +48,8 @@ class InternalMetadata; // metadata_lite.h
class WireFormat; // wire_format.h
class MessageSetFieldSkipperUsingCord;
// extension_set_heavy.cc
class UnknownFieldParserHelper;
struct UnknownFieldSetTestPeer;
#if defined(PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE)
using UFSStringView = absl::string_view;
@ -87,7 +90,13 @@ class PROTOBUF_EXPORT UnknownField {
inline void set_fixed32(uint32_t value);
inline void set_fixed64(uint64_t value);
inline void set_length_delimited(absl::string_view value);
// template to avoid ambiguous overload resolution.
template <int&...>
inline void set_length_delimited(std::string&& value);
inline void set_length_delimited(const absl::Cord& value);
#if !defined(PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE)
inline std::string* mutable_length_delimited();
#endif // PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE
inline UnknownFieldSet* mutable_group();
inline size_t GetLengthDelimitedSize() const;
@ -191,7 +200,14 @@ class PROTOBUF_EXPORT UnknownFieldSet {
void AddFixed32(int number, uint32_t value);
void AddFixed64(int number, uint64_t value);
void AddLengthDelimited(int number, absl::string_view value);
// template to avoid ambiguous overload resolution.
template <int&...>
void AddLengthDelimited(int number, std::string&& value);
void AddLengthDelimited(int number, const absl::Cord& value);
#if !defined(PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE)
std::string* AddLengthDelimited(int number);
#endif // PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE
UnknownFieldSet* AddGroup(int number);
// Adds an unknown field from another set.
@ -233,6 +249,10 @@ class PROTOBUF_EXPORT UnknownFieldSet {
: UnknownFieldSet(arena) {}
private:
friend internal::WireFormat;
friend internal::UnknownFieldParserHelper;
friend internal::UnknownFieldSetTestPeer;
using InternalArenaConstructable_ = void;
using DestructorSkippable_ = void;
@ -241,6 +261,14 @@ class PROTOBUF_EXPORT UnknownFieldSet {
Arena* arena() { return fields_.GetArena(); }
// Returns a buffer of `size` chars for the user to fill in.
// The buffer is potentially uninitialized memory. Failing to write to it
// might lead to undefined behavior when reading it later.
// Prefer the overloads above when possible. Calling this API without
// validating the `size` parameter can lead to unintentional memory usage and
// potential OOM.
absl::Span<char> AddLengthDelimitedUninitialized(int number, size_t size);
void ClearFallback();
void SwapSlow(UnknownFieldSet* other);
@ -275,7 +303,7 @@ inline void WriteVarint(uint32_t num, uint64_t val, UnknownFieldSet* unknown) {
}
inline void WriteLengthDelimited(uint32_t num, absl::string_view val,
UnknownFieldSet* unknown) {
unknown->AddLengthDelimited(num)->assign(val.data(), val.size());
unknown->AddLengthDelimited(num, val);
}
PROTOBUF_EXPORT
@ -331,7 +359,10 @@ inline UnknownField* UnknownFieldSet::mutable_field(int index) {
inline void UnknownFieldSet::AddLengthDelimited(int number,
const absl::string_view value) {
AddLengthDelimited(number)->assign(value.data(), value.size());
auto field = AddLengthDelimitedUninitialized(number, value.size());
if (!value.empty()) {
memcpy(field.data(), value.data(), value.size());
}
}
inline int UnknownField::number() const { return static_cast<int>(number_); }
@ -376,10 +407,21 @@ inline void UnknownField::set_length_delimited(const absl::string_view value) {
assert(type() == TYPE_LENGTH_DELIMITED);
data_.string_value->assign(value.data(), value.size());
}
// Replaces the payload of a length-delimited field by move-assigning `value`
// into the field's string. The field must already be TYPE_LENGTH_DELIMITED.
template <int&...>
inline void UnknownField::set_length_delimited(std::string&& value) {
  assert(type() == TYPE_LENGTH_DELIMITED);
  std::string& payload = *data_.string_value;
  payload = std::move(value);
}
// Replaces the payload of a length-delimited field with a copy of the bytes
// held by `value`. The field must already be TYPE_LENGTH_DELIMITED.
inline void UnknownField::set_length_delimited(const absl::Cord& value) {
  assert(type() == TYPE_LENGTH_DELIMITED);
  std::string* const payload = data_.string_value;
  absl::CopyCordToString(value, payload);
}
#if !defined(PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE)
// Returns the field's payload string for in-place modification. The field
// must be TYPE_LENGTH_DELIMITED. Compiled out when
// PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE is defined.
inline std::string* UnknownField::mutable_length_delimited() {
  assert(type() == TYPE_LENGTH_DELIMITED);
  std::string* const payload = data_.string_value;
  return payload;
}
#endif  // PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE
inline UnknownFieldSet* UnknownField::mutable_group() {
assert(type() == TYPE_GROUP);
return data_.group_;
@ -397,6 +439,8 @@ inline size_t UnknownField::GetLengthDelimitedSize() const {
inline void UnknownField::SetType(Type type) { type_ = type; }
extern template void UnknownFieldSet::AddLengthDelimited(int, std::string&&);
namespace internal {
// Add specialization of InternalMetadata::Container to provide arena support.

@ -14,6 +14,7 @@
#include "google/protobuf/unknown_field_set.h"
#include <cstddef>
#include <string>
#include <vector>
@ -39,10 +40,27 @@
namespace google {
namespace protobuf {
namespace internal {
// Test-only peer that forwards to UnknownFieldSet's private
// AddLengthDelimitedUninitialized() so tests can exercise it directly.
struct UnknownFieldSetTestPeer {
  // Forwards to `set.AddLengthDelimitedUninitialized(number, length)` and
  // returns whatever that call returns.
  static auto AddLengthDelimitedUninitialized(UnknownFieldSet& set, int number,
                                              size_t length) {
    auto buffer = set.AddLengthDelimitedUninitialized(number, length);
    return buffer;
  }
};
} // namespace internal
using internal::WireFormat;
using ::testing::ElementsAre;
// Round-trips `set` through its serialized form: serializes the unknown
// fields to a string and parses that string into a fresh message of type T.
// Both steps are wrapped in ABSL_CHECK.
template <typename T>
T UnknownToProto(const UnknownFieldSet& set) {
  T parsed;
  std::string wire_bytes;
  const bool serialized = set.SerializeToString(&wire_bytes);
  ABSL_CHECK(serialized);
  const bool reparsed = parsed.ParseFromString(wire_bytes);
  ABSL_CHECK(reparsed);
  return parsed;
}
class UnknownFieldSetTest : public testing::Test {
protected:
void SetUp() override {
@ -181,7 +199,13 @@ static void PopulateUFS(UnknownFieldSet& set) {
node->AddVarint(1, 100);
const char* long_str = "This is a very long string, not sso";
node->AddLengthDelimited(2, long_str);
node->AddLengthDelimited(2, std::string(long_str));
node->AddLengthDelimited(2, absl::Cord(long_str));
internal::UnknownFieldSetTestPeer::AddLengthDelimitedUninitialized(*node, 2,
100);
#if !defined(PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE)
*node->AddLengthDelimited(3) = long_str;
#endif // PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE
// Test some recursion too.
node = node->AddGroup(4);
}
@ -643,12 +667,20 @@ TEST_F(UnknownFieldSetTest, SpaceUsed) {
shadow_vector.Add();
EXPECT_EQ(total(), empty_message.SpaceUsedLong()) << "Var";
#if !defined(PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE)
str = unknown_fields->AddLengthDelimited(1);
shadow_vector.Add();
EXPECT_EQ(total(), empty_message.SpaceUsedLong()) << "Str";
str->assign(sizeof(std::string) + 1, 'x');
EXPECT_EQ(total(), empty_message.SpaceUsedLong()) << "Str2";
#else
std::string fake_str(31, 'a');
str = &fake_str;
unknown_fields->AddLengthDelimited(1, fake_str);
shadow_vector.Add();
EXPECT_EQ(total(), empty_message.SpaceUsedLong()) << "Str2";
#endif // PROTOBUF_FUTURE_STRING_VIEW_RETURN_TYPE
group = unknown_fields->AddGroup(1);
shadow_vector.Add();
@ -751,7 +783,14 @@ TEST_F(UnknownFieldSetTest, SerializeToString) {
field_set.AddVarint(1, -1);
field_set.AddVarint(2, -2);
field_set.AddLengthDelimited(44, "str");
field_set.AddLengthDelimited(44, "byv");
field_set.AddLengthDelimited(44, std::string("byv"));
field_set.AddLengthDelimited(44,
absl::Cord("this came from cord and is long"));
memcpy(internal::UnknownFieldSetTestPeer::AddLengthDelimitedUninitialized(
field_set, 44, 10)
.data(),
"0123456789", 10);
field_set.AddFixed32(7, 7);
field_set.AddFixed64(8, 8);
@ -760,10 +799,8 @@ TEST_F(UnknownFieldSetTest, SerializeToString) {
group_field_set = field_set.AddGroup(46);
group_field_set->AddVarint(47, 2048);
unittest::TestAllTypes message;
std::string serialized_message;
ASSERT_TRUE(field_set.SerializeToString(&serialized_message));
ASSERT_TRUE(message.ParseFromString(serialized_message));
unittest::TestAllTypes message =
UnknownToProto<unittest::TestAllTypes>(field_set);
EXPECT_EQ(message.optional_int32(), -1);
EXPECT_EQ(message.optional_int64(), -2);
@ -771,8 +808,9 @@ TEST_F(UnknownFieldSetTest, SerializeToString) {
EXPECT_EQ(message.optional_uint64(), 4);
EXPECT_EQ(message.optional_fixed32(), 7);
EXPECT_EQ(message.optional_fixed64(), 8);
EXPECT_EQ(message.repeated_string(0), "str");
EXPECT_EQ(message.repeated_string(1), "byv");
EXPECT_THAT(message.repeated_string(),
ElementsAre("str", "byv", "this came from cord and is long",
"0123456789"));
EXPECT_EQ(message.repeatedgroup(0).a(), 1024);
EXPECT_EQ(message.repeatedgroup(1).a(), 2048);
}
@ -818,6 +856,42 @@ TEST_F(UnknownFieldSetTest, SerializeToCord_TestPackedTypes) {
EXPECT_THAT(message.packed_uint64(), ElementsAre(5, 6, 7));
}
// Verifies that the UnknownField setters replace the value previously stored
// in a field. After each mutation the set is converted to a TestAllTypes
// message via UnknownToProto() and the corresponding typed accessor is
// checked.
TEST(UnknownFieldTest, SettersOverrideTheDataProperly) {
using T = unittest::TestAllTypes;
UnknownFieldSet set;
// Seed one field of each kind. The mutable_field(i) calls further down rely
// on the fields being addressable in the order they are added here.
set.AddVarint(T::kOptionalInt32FieldNumber, 2);
set.AddFixed32(T::kOptionalFixed32FieldNumber, 3);
set.AddFixed64(T::kOptionalFixed64FieldNumber, 4);
set.AddLengthDelimited(T::kOptionalStringFieldNumber, "5");
// Baseline: the initial values are visible through the typed message.
T message = UnknownToProto<T>(set);
EXPECT_EQ(message.optional_int32(), 2);
EXPECT_EQ(message.optional_fixed32(), 3);
EXPECT_EQ(message.optional_fixed64(), 4);
EXPECT_EQ(message.optional_string(), "5");
// Overwrite every field in place; the string literal argument exercises the
// absl::string_view overload of set_length_delimited().
set.mutable_field(0)->set_varint(22);
set.mutable_field(1)->set_fixed32(33);
set.mutable_field(2)->set_fixed64(44);
set.mutable_field(3)->set_length_delimited("55");
message = UnknownToProto<T>(set);
EXPECT_EQ(message.optional_int32(), 22);
EXPECT_EQ(message.optional_fixed32(), 33);
EXPECT_EQ(message.optional_fixed64(), 44);
EXPECT_EQ(message.optional_string(), "55");
// std::string&& overload.
set.mutable_field(3)->set_length_delimited(std::string("555"));
message = UnknownToProto<T>(set);
EXPECT_EQ(message.optional_string(), "555");
// absl::Cord overload.
set.mutable_field(3)->set_length_delimited(absl::Cord("5555"));
message = UnknownToProto<T>(set);
EXPECT_EQ(message.optional_string(), "5555");
}
} // namespace
} // namespace protobuf
} // namespace google

@ -13,6 +13,7 @@
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/absl_check.h"
@ -84,9 +85,21 @@ bool WireFormat::SkipField(io::CodedInputStream* input, uint32_t tag,
if (!input->ReadVarint32(&length)) return false;
if (unknown_fields == nullptr) {
if (!input->Skip(length)) return false;
} else if (length > 1'000'000) {
// If the provided length is too long, use the `std::string` approach,
// which will grow as it reads data instead of allocating all at the
// beginning. This protects against malformed input.
// Any reasonable value here is fine.
std::string str;
if (!input->ReadString(&str, length)) {
return false;
}
unknown_fields->AddLengthDelimited(number, std::move(str));
} else {
if (!input->ReadString(unknown_fields->AddLengthDelimited(number),
length)) {
if (!input->ReadRaw(
unknown_fields->AddLengthDelimitedUninitialized(number, length)
.data(),
length)) {
return false;
}
}
@ -362,8 +375,10 @@ bool WireFormat::SkipMessageSetField(io::CodedInputStream* input,
UnknownFieldSet* unknown_fields) {
uint32_t length;
if (!input->ReadVarint32(&length)) return false;
return input->ReadString(unknown_fields->AddLengthDelimited(field_number),
length);
return input->ReadRaw(
unknown_fields->AddLengthDelimitedUninitialized(field_number, length)
.data(),
length);
}
bool WireFormat::ParseAndMergeMessageSetField(uint32_t field_number,

Loading…
Cancel
Save