Apply patch

pull/10546/head
Deanna Garcia 2 years ago
parent db38a8c2da
commit 6cf8108271
  1. 27
      src/google/protobuf/extension_set_inl.h
  2. 26
      src/google/protobuf/wire_format.cc
  3. 27
      src/google/protobuf/wire_format_lite.h
  4. 104
      src/google/protobuf/wire_format_unittest.inc

@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
const char* ptr, const Msg* extendee, internal::InternalMetadata* metadata, const char* ptr, const Msg* extendee, internal::InternalMetadata* metadata,
internal::ParseContext* ctx) { internal::ParseContext* ctx) {
std::string payload; std::string payload;
uint32_t type_id = 0; uint32_t type_id;
bool payload_read = false; enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;
while (!ctx->Done(&ptr)) { while (!ctx->Done(&ptr)) {
uint32_t tag = static_cast<uint8_t>(*ptr++); uint32_t tag = static_cast<uint8_t>(*ptr++);
if (tag == WireFormatLite::kMessageSetTypeIdTag) { if (tag == WireFormatLite::kMessageSetTypeIdTag) {
uint64_t tmp; uint64_t tmp;
ptr = ParseBigVarint(ptr, &tmp); ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
type_id = tmp; if (state == State::kNoTag) {
if (payload_read) { type_id = tmp;
state = State::kHasType;
} else if (state == State::kHasPayload) {
type_id = tmp;
ExtensionInfo extension; ExtensionInfo extension;
bool was_packed_on_wire; bool was_packed_on_wire;
if (!FindExtension(2, type_id, extendee, ctx, &extension, if (!FindExtension(2, type_id, extendee, ctx, &extension,
@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit()); tmp_ctx.EndedAtLimit());
} }
type_id = 0; state = State::kDone;
} }
} else if (tag == WireFormatLite::kMessageSetMessageTag) { } else if (tag == WireFormatLite::kMessageSetMessageTag) {
if (type_id != 0) { if (state == State::kHasType) {
ptr = ParseFieldMaybeLazily(static_cast<uint64_t>(type_id) * 8 + 2, ptr, ptr = ParseFieldMaybeLazily(static_cast<uint64_t>(type_id) * 8 + 2, ptr,
extendee, metadata, ctx); extendee, metadata, ctx);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
type_id = 0; state = State::kDone;
} else { } else {
std::string tmp;
int32_t size = ReadSize(&ptr); int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload); ptr = ctx->ReadString(ptr, size, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
payload_read = true; if (state == State::kNoTag) {
payload = std::move(tmp);
state = State::kHasPayload;
}
} }
} else { } else {
ptr = ReadTag(ptr - 1, &tag); ptr = ReadTag(ptr - 1, &tag);

@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser {
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) { const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
// Parse a MessageSetItem // Parse a MessageSetItem
auto metadata = reflection->MutableInternalMetadata(msg); auto metadata = reflection->MutableInternalMetadata(msg);
enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;
std::string payload; std::string payload;
uint32_t type_id = 0; uint32_t type_id = 0;
bool payload_read = false;
while (!ctx->Done(&ptr)) { while (!ctx->Done(&ptr)) {
// We use 64 bit tags in order to allow typeid's that span the whole // We use 64 bit tags in order to allow typeid's that span the whole
// range of 32 bit numbers. // range of 32 bit numbers.
@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser {
uint64_t tmp; uint64_t tmp;
ptr = ParseBigVarint(ptr, &tmp); ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
type_id = tmp; if (state == State::kNoTag) {
if (payload_read) { type_id = tmp;
state = State::kHasType;
} else if (state == State::kHasPayload) {
type_id = tmp;
const FieldDescriptor* field; const FieldDescriptor* field;
if (ctx->data().pool == nullptr) { if (ctx->data().pool == nullptr) {
field = reflection->FindKnownExtensionByNumber(type_id); field = reflection->FindKnownExtensionByNumber(type_id);
@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser {
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit()); tmp_ctx.EndedAtLimit());
} }
type_id = 0; state = State::kDone;
} }
continue; continue;
} else if (tag == WireFormatLite::kMessageSetMessageTag) { } else if (tag == WireFormatLite::kMessageSetMessageTag) {
if (type_id == 0) { if (state == State::kNoTag) {
int32_t size = ReadSize(&ptr); int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload); ptr = ctx->ReadString(ptr, size, &payload);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
payload_read = true; state = State::kHasPayload;
} else { } else if (state == State::kHasType) {
// We're now parsing the payload // We're now parsing the payload
const FieldDescriptor* field = nullptr; const FieldDescriptor* field = nullptr;
if (descriptor->IsExtensionNumber(type_id)) { if (descriptor->IsExtensionNumber(type_id)) {
@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser {
ptr = WireFormat::_InternalParseAndMergeField( ptr = WireFormat::_InternalParseAndMergeField(
msg, ptr, ctx, static_cast<uint64_t>(type_id) * 8 + 2, reflection, msg, ptr, ctx, static_cast<uint64_t>(type_id) * 8 + 2, reflection,
field); field);
type_id = 0; state = State::kDone;
} else {
int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->Skip(ptr, size);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
} }
} else { } else {
// An unknown field in MessageSetItem. // An unknown field in MessageSetItem.

@ -1834,6 +1834,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
// we can parse it later. // we can parse it later.
std::string message_data; std::string message_data;
enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;
while (true) { while (true) {
const uint32_t tag = input->ReadTagNoLastTag(); const uint32_t tag = input->ReadTagNoLastTag();
if (tag == 0) return false; if (tag == 0) return false;
@ -1842,26 +1845,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
case WireFormatLite::kMessageSetTypeIdTag: { case WireFormatLite::kMessageSetTypeIdTag: {
uint32_t type_id; uint32_t type_id;
if (!input->ReadVarint32(&type_id)) return false; if (!input->ReadVarint32(&type_id)) return false;
last_type_id = type_id; if (state == State::kNoTag) {
last_type_id = type_id;
if (!message_data.empty()) { state = State::kHasType;
} else if (state == State::kHasPayload) {
// We saw some message data before the type_id. Have to parse it // We saw some message data before the type_id. Have to parse it
// now. // now.
io::CodedInputStream sub_input( io::CodedInputStream sub_input(
reinterpret_cast<const uint8_t*>(message_data.data()), reinterpret_cast<const uint8_t*>(message_data.data()),
static_cast<int>(message_data.size())); static_cast<int>(message_data.size()));
sub_input.SetRecursionLimit(input->RecursionBudget()); sub_input.SetRecursionLimit(input->RecursionBudget());
if (!ms.ParseField(last_type_id, &sub_input)) { if (!ms.ParseField(type_id, &sub_input)) {
return false; return false;
} }
message_data.clear(); message_data.clear();
state = State::kDone;
} }
break; break;
} }
case WireFormatLite::kMessageSetMessageTag: { case WireFormatLite::kMessageSetMessageTag: {
if (last_type_id == 0) { if (state == State::kHasType) {
// Already saw type_id, so we can parse this directly.
if (!ms.ParseField(last_type_id, input)) {
return false;
}
state = State::kDone;
} else if (state == State::kNoTag) {
// We haven't seen a type_id yet. Append this data to message_data. // We haven't seen a type_id yet. Append this data to message_data.
uint32_t length; uint32_t length;
if (!input->ReadVarint32(&length)) return false; if (!input->ReadVarint32(&length)) return false;
@ -1872,11 +1883,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
auto ptr = reinterpret_cast<uint8_t*>(&message_data[0]); auto ptr = reinterpret_cast<uint8_t*>(&message_data[0]);
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr); ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
if (!input->ReadRaw(ptr, length)) return false; if (!input->ReadRaw(ptr, length)) return false;
state = State::kHasPayload;
} else { } else {
// Already saw type_id, so we can parse this directly. if (!ms.SkipField(tag, input)) return false;
if (!ms.ParseField(last_type_id, input)) {
return false;
}
} }
break; break;

@ -581,28 +581,54 @@ TEST(WireFormatTest, ParseMessageSet) {
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString()); EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
} }
TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) { namespace {
std::string BuildMessageSetItemStart() {
std::string data; std::string data;
{ {
UNITTEST::TestMessageSetExtension1 message;
message.set_i(123);
// Build a MessageSet manually with its message content put before its
// type_id.
io::StringOutputStream output_stream(&data); io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream); io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag); coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
}
return data;
}
std::string BuildMessageSetItemEnd() {
std::string data;
{
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
}
return data;
}
std::string BuildMessageSetTestExtension1(int value = 123) {
std::string data;
{
UNITTEST::TestMessageSetExtension1 message;
message.set_i(value);
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
// Write the message content first. // Write the message content first.
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber, WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED, WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&coded_output); &coded_output);
coded_output.WriteVarint32(message.ByteSizeLong()); coded_output.WriteVarint32(message.ByteSizeLong());
message.SerializeWithCachedSizes(&coded_output); message.SerializeWithCachedSizes(&coded_output);
// Write the type id. }
uint32_t type_id = message.GetDescriptor()->extension(0)->number(); return data;
}
std::string BuildMessageSetItemTypeId(int extension_number) {
std::string data;
{
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber, WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
type_id, &coded_output); extension_number, &coded_output);
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
} }
return data;
}
void ValidateTestMessageSet(const std::string& test_case,
const std::string& data) {
SCOPED_TRACE(test_case);
{ {
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message_set; PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message_set;
ASSERT_TRUE(message_set.ParseFromString(data)); ASSERT_TRUE(message_set.ParseFromString(data));
@ -612,6 +638,11 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
.GetExtension( .GetExtension(
UNITTEST::TestMessageSetExtension1::message_set_extension) UNITTEST::TestMessageSetExtension1::message_set_extension)
.i()); .i());
// Make sure it does not contain anything else.
message_set.ClearExtension(
UNITTEST::TestMessageSetExtension1::message_set_extension);
EXPECT_EQ(message_set.SerializeAsString(), "");
} }
{ {
// Test parse the message via Reflection. // Test parse the message via Reflection.
@ -627,6 +658,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
UNITTEST::TestMessageSetExtension1::message_set_extension) UNITTEST::TestMessageSetExtension1::message_set_extension)
.i()); .i());
} }
{
// Test parse the message via DynamicMessage.
DynamicMessageFactory factory;
std::unique_ptr<Message> msg(
factory
.GetPrototype(
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet::descriptor())
->New());
msg->ParseFromString(data);
auto* reflection = msg->GetReflection();
std::vector<const FieldDescriptor*> fields;
reflection->ListFields(*msg, &fields);
ASSERT_EQ(fields.size(), 1);
const auto& sub = reflection->GetMessage(*msg, fields[0]);
reflection = sub.GetReflection();
EXPECT_EQ(123, reflection->GetInt32(
sub, sub.GetDescriptor()->FindFieldByName("i")));
}
}
} // namespace
// Builds one MessageSet item twice -- once with the type id ahead of the
// message payload and once with the payload ahead of the type id -- and hands
// each encoding to ValidateTestMessageSet under a descriptive label.
TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
  const std::string item_start = BuildMessageSetItemStart();
  const std::string item_end = BuildMessageSetItemEnd();
  const std::string type_id = BuildMessageSetItemTypeId(
      UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
  const std::string payload = BuildMessageSetTestExtension1();

  // Type id first, then payload.
  ValidateTestMessageSet("id + message",
                         item_start + type_id + payload + item_end);
  // Payload first, then type id.
  ValidateTestMessageSet("message + id",
                         item_start + payload + type_id + item_end);
}
// Feeds ValidateTestMessageSet a series of MessageSet items that contain
// either a second type id (123456) or a second message payload (built with
// value 321) in addition to the expected id/payload pair, in every relative
// ordering listed below.
TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
  const std::string item_start = BuildMessageSetItemStart();
  const std::string item_end = BuildMessageSetItemEnd();
  const std::string type_id = BuildMessageSetItemTypeId(
      UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
  const std::string extra_type_id = BuildMessageSetItemTypeId(123456);
  const std::string payload = BuildMessageSetTestExtension1();
  const std::string extra_payload = BuildMessageSetTestExtension1(321);

  // Items that carry two type ids.
  ValidateTestMessageSet(
      "id + other_id + message",
      item_start + type_id + extra_type_id + payload + item_end);
  ValidateTestMessageSet(
      "id + message + other_id",
      item_start + type_id + payload + extra_type_id + item_end);
  ValidateTestMessageSet(
      "message + id + other_id",
      item_start + payload + type_id + extra_type_id + item_end);

  // Items that carry two message payloads.
  ValidateTestMessageSet(
      "id + message + other_message",
      item_start + type_id + payload + extra_payload + item_end);
  ValidateTestMessageSet(
      "message + id + other_message",
      item_start + payload + type_id + extra_payload + item_end);
  ValidateTestMessageSet(
      "message + other_message + id",
      item_start + payload + extra_payload + type_id + item_end);
}
void SerializeReverseOrder( void SerializeReverseOrder(

Loading…
Cancel
Save