diff --git a/src/google/protobuf/extension_set_inl.h b/src/google/protobuf/extension_set_inl.h index 4554d95083..4fdc3991f4 100644 --- a/src/google/protobuf/extension_set_inl.h +++ b/src/google/protobuf/extension_set_inl.h @@ -215,7 +215,8 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( if (tag == WireFormatLite::kMessageSetTypeIdTag) { uint64_t tmp; ptr = ParseBigVarint(ptr, &tmp); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + // We should fail parsing if type id is 0. + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr && tmp != 0); if (state == State::kNoTag) { type_id = tmp; state = State::kHasType; diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc index 0e7fedfd33..5359f4e434 100644 --- a/src/google/protobuf/wire_format.cc +++ b/src/google/protobuf/wire_format.cc @@ -671,7 +671,8 @@ struct WireFormat::MessageSetParser { if (tag == WireFormatLite::kMessageSetTypeIdTag) { uint64_t tmp; ptr = ParseBigVarint(ptr, &tmp); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + // We should fail parsing if type id is 0. + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr && tmp != 0); if (state == State::kNoTag) { type_id = tmp; state = State::kHasType; diff --git a/src/google/protobuf/wire_format_lite.h b/src/google/protobuf/wire_format_lite.h index efe48ce161..3097a14f51 100644 --- a/src/google/protobuf/wire_format_lite.h +++ b/src/google/protobuf/wire_format_lite.h @@ -1890,7 +1890,8 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { switch (tag) { case WireFormatLite::kMessageSetTypeIdTag: { uint32_t type_id; - if (!input->ReadVarint32(&type_id)) return false; + // We should fail parsing if type id is 0. + if (!input->ReadVarint32(&type_id) || type_id == 0) return false; if (state == State::kNoTag) { last_type_id = type_id; state = State::kHasType; diff --git a/src/google/protobuf/wire_format_unittest.inc b/src/google/protobuf/wire_format_unittest.inc index 4c7d047c71..45a74bc964 100644 --- a/src/google/protobuf/wire_format_unittest.inc +++ b/src/google/protobuf/wire_format_unittest.inc @@ -579,6 +579,31 @@ TEST(WireFormatTest, ParseMessageSet) { EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString()); } +TEST(WireFormatTest, MessageSetUnknownButValidTypeId) { + const char encoded[] = { + 013, // 1: SGROUP + 032, 2, // 3:LEN 2 + 010, 0, // 1:0 + 020, 4, // 2:4 + 014 // 1: EGROUP + }; + PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message; + EXPECT_TRUE(message.ParseFromArray(encoded, sizeof(encoded))); +} + +TEST(WireFormatTest, MessageSetInvalidTypeId) { + // "type_id" is 0 and should fail to parse. + const char encoded[] = { + 013, // 1: SGROUP + 032, 2, // 3:LEN 2 + 010, 0, // 1:0 + 020, 0, // 2:0 + 014 // 1: EGROUP + }; + PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message; + EXPECT_FALSE(message.ParseFromArray(encoded, sizeof(encoded))); +} + namespace { std::string BuildMessageSetItemStart() { std::string data;