diff --git a/src/google/protobuf/compiler/cpp/parse_function_generator.cc b/src/google/protobuf/compiler/cpp/parse_function_generator.cc index 791a7d5503..a296d411e7 100644 --- a/src/google/protobuf/compiler/cpp/parse_function_generator.cc +++ b/src/google/protobuf/compiler/cpp/parse_function_generator.cc @@ -216,17 +216,6 @@ void ParseFunctionGenerator::GenerateTailcallParseFunction(Formatter& format) { "}\n\n"); } -static bool NeedsUnknownEnumSupport(const Descriptor* descriptor) { - for (int i = 0; i < descriptor->field_count(); ++i) { - auto* field = descriptor->field(i); - if (field->is_repeated() && field->cpp_type() == field->CPPTYPE_ENUM && - !internal::cpp::HasPreservingUnknownEnumSemantics(field)) { - return true; - } - } - return false; -} - void ParseFunctionGenerator::GenerateTailcallFallbackFunction( Formatter& format) { ABSL_CHECK(should_generate_tctable()); @@ -236,21 +225,18 @@ void ParseFunctionGenerator::GenerateTailcallFallbackFunction( format.Indent(); format("auto* typed_msg = static_cast<$classname$*>(msg);\n"); - // If we need a side channel, generate the check to jump to the generic - // handler to deal with the side channel data. - if (NeedsUnknownEnumSupport(descriptor_)) { - format( - "if (PROTOBUF_PREDICT_FALSE(\n" - " _pbi::TcParser::MustFallbackToGeneric(PROTOBUF_TC_PARAM_PASS))) " - "{\n" - " PROTOBUF_MUSTTAIL return " - "::_pbi::TcParser::GenericFallback$1$(PROTOBUF_TC_PARAM_PASS);\n" - "}\n", - GetOptimizeFor(descriptor_->file(), options_) == - FileOptions::LITE_RUNTIME - ? "Lite" - : ""); - } + // Generate the check to jump to the generic handler to deal with the side + // channel data. + format( + "if (PROTOBUF_PREDICT_FALSE(\n" + " _pbi::TcParser::MustFallbackToGeneric(PROTOBUF_TC_PARAM_PASS))) " + "{\n" + " PROTOBUF_MUSTTAIL return " + "::_pbi::TcParser::GenericFallback$1$(PROTOBUF_TC_PARAM_PASS);\n" + "}\n", + GetOptimizeFor(descriptor_->file(), options_) == FileOptions::LITE_RUNTIME + ? "Lite" + : ""); if (num_hasbits_ > 0) { // Sync hasbits @@ -1147,7 +1133,7 @@ void ParseFunctionGenerator::GenerateFieldBody( format.Set("enum_type", QualifiedClassName(field->enum_type(), options_)); format( - "$uint32$ val = ::$proto_ns$::internal::ReadVarint32(&ptr);\n" + "$int32$ val = ::$proto_ns$::internal::ReadVarint32(&ptr);\n" "CHK_(ptr);\n"); if (!internal::cpp::HasPreservingUnknownEnumSemantics(field)) { format( diff --git a/src/google/protobuf/descriptor.pb.cc b/src/google/protobuf/descriptor.pb.cc index ce8901ec0e..88e9833a48 100644 --- a/src/google/protobuf/descriptor.pb.cc +++ b/src/google/protobuf/descriptor.pb.cc @@ -4252,7 +4252,7 @@ const char* FieldDescriptorProto::_InternalParse(const char* ptr, ::_pbi::ParseC // optional .google.protobuf.FieldDescriptorProto.Label label = 4; case 4: if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 32)) { - ::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); + ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); CHK_(ptr); if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label_IsValid(static_cast(val)))) { _internal_set_label(static_cast<::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label>(val)); @@ -4266,7 +4266,7 @@ const char* FieldDescriptorProto::_InternalParse(const char* ptr, ::_pbi::ParseC // optional .google.protobuf.FieldDescriptorProto.Type type = 5; case 5: if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 40)) { - ::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); + ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); CHK_(ptr); if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type_IsValid(static_cast(val)))) { _internal_set_type(static_cast<::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type>(val)); @@ -7026,7 +7026,7 @@ const char* FileOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext* c // optional .google.protobuf.FileOptions.OptimizeMode optimize_for = 9 [default = SPEED]; case 9: if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 72)) { - ::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); + ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); CHK_(ptr); if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode_IsValid(static_cast(val)))) { _internal_set_optimize_for(static_cast<::PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode>(val)); @@ -8297,7 +8297,7 @@ const char* FieldOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext* // optional .google.protobuf.FieldOptions.CType ctype = 1 [default = STRING]; case 1: if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 8)) { - ::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); + ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); CHK_(ptr); if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_CType_IsValid(static_cast(val)))) { _internal_set_ctype(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_CType>(val)); @@ -8341,7 +8341,7 @@ const char* FieldOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext* // optional .google.protobuf.FieldOptions.JSType jstype = 6 [default = JS_NORMAL]; case 6: if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 48)) { - ::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); + ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); CHK_(ptr); if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_JSType_IsValid(static_cast(val)))) { _internal_set_jstype(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_JSType>(val)); @@ -8385,7 +8385,7 @@ const char* FieldOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext* // optional .google.protobuf.FieldOptions.OptionRetention retention = 17; case 17: if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 136)) { - ::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); + ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); CHK_(ptr); if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionRetention_IsValid(static_cast(val)))) { _internal_set_retention(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionRetention>(val)); @@ -8399,7 +8399,7 @@ const char* FieldOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext* // optional .google.protobuf.FieldOptions.OptionTargetType target = 18; case 18: if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 144)) { - ::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); + ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); CHK_(ptr); if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionTargetType_IsValid(static_cast(val)))) { _internal_set_target(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionTargetType>(val)); @@ -9877,7 +9877,7 @@ const char* MethodOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext* // optional .google.protobuf.MethodOptions.IdempotencyLevel idempotency_level = 34 [default = IDEMPOTENCY_UNKNOWN]; case 34: if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 16)) { - ::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); + ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); CHK_(ptr); if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel_IsValid(static_cast(val)))) { _internal_set_idempotency_level(static_cast<::PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel>(val)); @@ -11581,7 +11581,7 @@ const char* GeneratedCodeInfo_Annotation::_InternalParse(const char* ptr, ::_pbi // optional .google.protobuf.GeneratedCodeInfo.Annotation.Semantic semantic = 5; case 5: if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 40)) { - ::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); + ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); CHK_(ptr); if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation_Semantic_IsValid(static_cast(val)))) { _internal_set_semantic(static_cast<::PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation_Semantic>(val)); diff --git a/src/google/protobuf/extension_set_inl.h b/src/google/protobuf/extension_set_inl.h index bd01e95044..4554d95083 100644 --- a/src/google/protobuf/extension_set_inl.h +++ b/src/google/protobuf/extension_set_inl.h @@ -141,14 +141,14 @@ const char* ExtensionSet::ParseFieldWithExtensionInfo( #undef HANDLE_FIXED_TYPE case WireFormatLite::TYPE_ENUM: { - uint64_t val; - ptr = VarintParse(ptr, &val); + uint64_t tmp; + ptr = VarintParse(ptr, &tmp); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - int value = val; + int value = tmp; if (!extension.enum_validity_check.func( extension.enum_validity_check.arg, value)) { - WriteVarint(number, val, metadata->mutable_unknown_fields()); + WriteVarint(number, value, metadata->mutable_unknown_fields()); } else if (extension.is_repeated) { AddEnum(number, WireFormatLite::TYPE_ENUM, extension.is_packed, value, extension.descriptor); diff --git a/src/google/protobuf/generated_message_tctable_impl.h b/src/google/protobuf/generated_message_tctable_impl.h index 35e5ec037c..541b475090 100644 --- a/src/google/protobuf/generated_message_tctable_impl.h +++ b/src/google/protobuf/generated_message_tctable_impl.h @@ -655,6 +655,7 @@ class PROTOBUF_EXPORT TcParser final { static const char* Error(PROTOBUF_TC_PARAM_NO_DATA_DECL); static const char* FastUnknownEnumFallback(PROTOBUF_TC_PARAM_DECL); + static const char* MpUnknownEnumFallback(PROTOBUF_TC_PARAM_DECL); class ScopedArenaSwap; @@ -763,8 +764,8 @@ class PROTOBUF_EXPORT TcParser final { const char* ptr, Arena* arena, SerialArena* serial_arena, ParseContext* ctx, RepeatedPtrField& field); - static void UnknownPackedEnum(MessageLite* msg, const TcParseTableBase* table, - uint32_t tag, int32_t enum_value); + static void AddUnknownEnum(MessageLite* msg, const TcParseTableBase* table, + uint32_t tag, int32_t enum_value); // Mini field lookup: static const TcParseTableBase::FieldEntry* FindFieldEntry( diff --git a/src/google/protobuf/generated_message_tctable_lite.cc b/src/google/protobuf/generated_message_tctable_lite.cc index 1455a9e286..4a8e8275d7 100644 --- a/src/google/protobuf/generated_message_tctable_lite.cc +++ b/src/google/protobuf/generated_message_tctable_lite.cc @@ -1224,16 +1224,34 @@ PROTOBUF_NOINLINE const char* TcParser::FastZ64P2(PROTOBUF_TC_PARAM_DECL) { PROTOBUF_NOINLINE const char* TcParser::FastUnknownEnumFallback( PROTOBUF_TC_PARAM_DECL) { - // If we know we want to put this field directly into the unknown field set, - // then we can skip the call to MiniParse and directly call table->fallback. - // However, we first have to update `data` to contain the decoded tag. + // Skip MiniParse/fallback and insert the element directly into the unknown + // field set. We also normalize the value into an int32 as we do for known + // enum values. uint32_t tag; ptr = ReadTag(ptr, &tag); if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) { PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); } - data.data = tag; - PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); + uint64_t tmp; + ptr = ParseVarint(ptr, &tmp); + if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) { + PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); + } + AddUnknownEnum(msg, table, tag, static_cast(tmp)); + PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS); +} + +PROTOBUF_NOINLINE const char* TcParser::MpUnknownEnumFallback( + PROTOBUF_TC_PARAM_DECL) { + // Like FastUnknownEnumFallback, but with the Mp ABI. + uint32_t tag = data.tag(); + uint64_t tmp; + ptr = ParseVarint(ptr, &tmp); + if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) { + PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); + } + AddUnknownEnum(msg, table, tag, static_cast(tmp)); + PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS); } template @@ -1325,9 +1343,10 @@ const TcParser::UnknownFieldOps& TcParser::GetUnknownFieldOps( return *reinterpret_cast(ptr); } -PROTOBUF_NOINLINE void TcParser::UnknownPackedEnum( - MessageLite* msg, const TcParseTableBase* table, uint32_t tag, - int32_t enum_value) { +PROTOBUF_NOINLINE void TcParser::AddUnknownEnum(MessageLite* msg, + const TcParseTableBase* table, + uint32_t tag, + int32_t enum_value) { GetUnknownFieldOps(table).write_varint(msg, tag >> 3, enum_value); } @@ -1351,7 +1370,7 @@ const char* TcParser::PackedEnum(PROTOBUF_TC_PARAM_DECL) { const TcParseTableBase::FieldAux aux = *table->field_aux(data.aux_idx()); return ctx->ReadPackedVarint(ptr, [=](int32_t value) { if (!EnumIsValidAux(value, xform_val, aux)) { - UnknownPackedEnum(msg, table, FastDecodeTag(saved_tag), value); + AddUnknownEnum(msg, table, FastDecodeTag(saved_tag), value); } else { field->Add(value); } @@ -1498,7 +1517,7 @@ const char* TcParser::PackedEnumSmallRange(PROTOBUF_TC_PARAM_DECL) { return ctx->ReadPackedVarint(ptr, [=](int32_t v) { if (PROTOBUF_PREDICT_FALSE(min > v || v > max)) { - UnknownPackedEnum(msg, table, FastDecodeTag(saved_tag), v); + AddUnknownEnum(msg, table, FastDecodeTag(saved_tag), v); } else { field->Add(v); } @@ -2021,7 +2040,7 @@ PROTOBUF_NOINLINE const char* TcParser::MpVarint(PROTOBUF_TC_PARAM_DECL) { if (is_validated_enum) { if (!EnumIsValidAux(tmp, xform_val, *table->field_aux(&entry))) { ptr = ptr2; - PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); + PROTOBUF_MUSTTAIL return MpUnknownEnumFallback(PROTOBUF_TC_PARAM_PASS); } } else if (is_zigzag) { tmp = WireFormatLite::ZigZagDecode32(static_cast(tmp)); @@ -2099,7 +2118,8 @@ PROTOBUF_NOINLINE const char* TcParser::MpRepeatedVarint( if (is_validated_enum) { if (!EnumIsValidAux(tmp, xform_val, *table->field_aux(&entry))) { ptr = ptr2; - PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); + PROTOBUF_MUSTTAIL return MpUnknownEnumFallback( + PROTOBUF_TC_PARAM_PASS); } } else if (is_zigzag) { tmp = WireFormatLite::ZigZagDecode32(tmp); @@ -2163,7 +2183,7 @@ PROTOBUF_NOINLINE const char* TcParser::MpPackedVarint(PROTOBUF_TC_PARAM_DECL) { const TcParseTableBase::FieldAux aux = *table->field_aux(entry.aux_idx); return ctx->ReadPackedVarint(ptr, [=](int32_t value) { if (!EnumIsValidAux(value, xform_val, aux)) { - UnknownPackedEnum(msg, table, data.tag(), value); + AddUnknownEnum(msg, table, data.tag(), value); } else { field->Add(value); } diff --git a/src/google/protobuf/message_unittest.inc b/src/google/protobuf/message_unittest.inc index 8cfe8c1d5a..f5e330e0a4 100644 --- a/src/google/protobuf/message_unittest.inc +++ b/src/google/protobuf/message_unittest.inc @@ -61,6 +61,7 @@ #include "absl/strings/substitute.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" #include "google/protobuf/generated_message_reflection.h" #include "google/protobuf/generated_message_tctable_impl.h" #include "google/protobuf/io/coded_stream.h" @@ -1310,18 +1311,62 @@ std::string EncodeOverlongEnum(int number, bool use_packed) { } } +std::string EncodeInt32Value(int number, int32_t value, + int non_canonical_bytes) { + uint8_t buf[100]; + uint8_t* p = buf; + + p = internal::WireFormatLite::WriteInt32ToArray(number, value, p); + p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); + return std::string(buf, p); +} + +std::string EncodeInt64Value(int number, int64_t value, int non_canonical_bytes, + bool use_packed = false) { + uint8_t buf[100]; + uint8_t* p = buf; + + if (use_packed) { + p = internal::WireFormatLite::WriteInt64NoTagToArray(value, p); + p = AddNonCanonicalBytes(buf, p, non_canonical_bytes); + + std::string payload(buf, p); + p = buf; + p = internal::WireFormatLite::WriteStringToArray(number, payload, p); + return std::string(buf, p); + + } else { + p = internal::WireFormatLite::WriteInt64ToArray(number, value, p); + p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); + return std::string(buf, p); + } +} + std::string EncodeOtherField() { UNITTEST::EnumParseTester obj; obj.set_other_field(1); return obj.SerializeAsString(); } +template +static std::vector GetFields() { + auto* descriptor = T::descriptor(); + std::vector fields; + for (int i = 0; i < descriptor->field_count(); ++i) { + fields.push_back(descriptor->field(i)); + } + for (int i = 0; i < descriptor->extension_count(); ++i) { + fields.push_back(descriptor->extension(i)); + } + return fields; +} + TEST(MESSAGE_TEST_NAME, TestEnumParsers) { UNITTEST::EnumParseTester obj; const auto other_field = EncodeOtherField(); - // Encode a boolean field for many different cases and verify that it can be + // Encode an enum field for many different cases and verify that it can be // parsed as expected. // There are: // - optional/repeated/packed fields @@ -1331,6 +1376,9 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) { // - label combinations to trigger different parsers: sequential, small // sequential, non-validated. + const std::vector fields = + GetFields(); + constexpr int kInvalidValue = 0x900913; auto* ref = obj.GetReflection(); auto* descriptor = obj.descriptor(); @@ -1347,8 +1395,7 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) { continue; } SCOPED_TRACE(add_garbage_bits); - for (int i = 0; i < descriptor->field_count(); ++i) { - const auto* field = descriptor->field(i); + for (auto field : fields) { if (field->name() == "other_field") continue; if (!field->is_repeated() && use_packed) continue; SCOPED_TRACE(field->full_name()); @@ -1421,6 +1468,52 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) { } } +TEST(MESSAGE_TEST_NAME, TestEnumParserForUnknownEnumValue) { + DynamicMessageFactory factory; + std::unique_ptr dynamic( + factory.GetPrototype(UNITTEST::EnumParseTester::descriptor())->New()); + + UNITTEST::EnumParseTester non_dynamic; + + // For unknown enum values, for consistency we must include the + // int32_t enum value in the unknown field set, which might not be exactly the + // same as the input. + auto* descriptor = non_dynamic.descriptor(); + + const std::vector fields = + GetFields(); + + for (bool use_dynamic : {false, true}) { + SCOPED_TRACE(use_dynamic); + for (auto field : fields) { + if (field->name() == "other_field") continue; + SCOPED_TRACE(field->full_name()); + for (bool use_packed : {false, true}) { + SCOPED_TRACE(use_packed); + if (!field->is_repeated() && use_packed) continue; + + // -2 is an invalid enum value on all the tests here. + // We will encode -2 as a positive int64 that is equivalent to + // int32_t{-2} when truncated. + constexpr int64_t minus_2_non_canonical = + static_cast(static_cast(int32_t{-2})); + static_assert(minus_2_non_canonical != -2, ""); + std::string encoded = EncodeInt64Value( + field->number(), minus_2_non_canonical, 0, use_packed); + + auto& obj = use_dynamic ? *dynamic : non_dynamic; + ASSERT_TRUE(obj.ParseFromString(encoded)); + + auto& unknown = obj.GetReflection()->GetUnknownFields(obj); + ASSERT_EQ(unknown.field_count(), 1); + EXPECT_EQ(unknown.field(0).number(), field->number()); + EXPECT_EQ(unknown.field(0).type(), unknown.field(0).TYPE_VARINT); + EXPECT_EQ(unknown.field(0).varint(), int64_t{-2}); + } + } + } +} + std::string EncodeBoolValue(int number, bool value, int non_canonical_bytes) { uint8_t buf[100]; uint8_t* p = buf; @@ -1443,6 +1536,9 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) { // - canonical and non-canonical encodings of the varint // - last vs not last field + const std::vector fields = + GetFields(); + auto* ref = obj.GetReflection(); auto* descriptor = obj.descriptor(); for (bool use_tail_field : {false, true}) { @@ -1456,8 +1552,7 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) { continue; } SCOPED_TRACE(add_garbage_bits); - for (int i = 0; i < descriptor->field_count(); ++i) { - const auto* field = descriptor->field(i); + for (auto field : fields) { if (field->name() == "other_field") continue; SCOPED_TRACE(field->full_name()); for (bool value : {false, true}) { @@ -1492,16 +1587,6 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) { } } -std::string EncodeInt32Value(int number, int32_t value, - int non_canonical_bytes) { - uint8_t buf[100]; - uint8_t* p = buf; - - p = internal::WireFormatLite::WriteInt32ToArray(number, value, p); - p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); - return std::string(buf, p); -} - TEST(MESSAGE_TEST_NAME, TestInt32Parsers) { UNITTEST::Int32ParseTester obj; @@ -1515,6 +1600,9 @@ TEST(MESSAGE_TEST_NAME, TestInt32Parsers) { // - canonical and non-canonical encodings of the varint // - last vs not last field + const std::vector fields = + GetFields(); + auto* ref = obj.GetReflection(); auto* descriptor = obj.descriptor(); for (bool use_tail_field : {false, true}) { @@ -1528,8 +1616,7 @@ TEST(MESSAGE_TEST_NAME, TestInt32Parsers) { continue; } SCOPED_TRACE(add_garbage_bits); - for (int i = 0; i < descriptor->field_count(); ++i) { - const auto* field = descriptor->field(i); + for (auto field : fields) { if (field->name() == "other_field") continue; SCOPED_TRACE(field->full_name()); for (int32_t value : {1, 0, -1, (std::numeric_limits::min)(), @@ -1565,16 +1652,6 @@ TEST(MESSAGE_TEST_NAME, TestInt32Parsers) { } } -std::string EncodeInt64Value(int number, int64_t value, - int non_canonical_bytes) { - uint8_t buf[100]; - uint8_t* p = buf; - - p = internal::WireFormatLite::WriteInt64ToArray(number, value, p); - p = AddNonCanonicalBytes(SkipTag(buf), p, non_canonical_bytes); - return std::string(buf, p); -} - TEST(MESSAGE_TEST_NAME, TestInt64Parsers) { UNITTEST::Int64ParseTester obj; @@ -1588,6 +1665,9 @@ TEST(MESSAGE_TEST_NAME, TestInt64Parsers) { // - canonical and non-canonical encodings of the varint // - last vs not last field + const std::vector fields = + GetFields(); + auto* ref = obj.GetReflection(); auto* descriptor = obj.descriptor(); for (bool use_tail_field : {false, true}) { @@ -1601,8 +1681,7 @@ TEST(MESSAGE_TEST_NAME, TestInt64Parsers) { continue; } SCOPED_TRACE(add_garbage_bits); - for (int i = 0; i < descriptor->field_count(); ++i) { - const auto* field = descriptor->field(i); + for (auto field : fields) { if (field->name() == "other_field") continue; SCOPED_TRACE(field->full_name()); for (int64_t value : {int64_t{1}, int64_t{0}, int64_t{-1}, @@ -1748,6 +1827,9 @@ TEST(MESSAGE_TEST_NAME, TestRepeatedStringParsers) { const auto* const descriptor = UNITTEST::StringParseTester::descriptor(); + const std::vector fields = + GetFields(); + static const size_t sso_capacity = std::string().capacity(); if (sso_capacity == 0) GTEST_SKIP(); // SSO, !SSO, and off-by-one just in case @@ -1755,8 +1837,7 @@ TEST(MESSAGE_TEST_NAME, TestRepeatedStringParsers) { {sso_capacity - 1, sso_capacity, sso_capacity + 1, sso_capacity + 2}) { SCOPED_TRACE(size); const std::string value = sample.substr(0, size); - for (int i = 0; i < descriptor->field_count(); ++i) { - const auto* field = descriptor->field(i); + for (auto field : fields) { SCOPED_TRACE(field->full_name()); const auto encoded = EncodeStringValue(field->number(), sample) + EncodeStringValue(field->number(), value); diff --git a/src/google/protobuf/parse_context.h b/src/google/protobuf/parse_context.h index 0169fe7828..e3f8ae0c67 100644 --- a/src/google/protobuf/parse_context.h +++ b/src/google/protobuf/parse_context.h @@ -1170,7 +1170,7 @@ PROTOBUF_NODISCARD const char* PackedEnumParser(void* object, const char* ptr, InternalMetadata* metadata, int field_num) { return ctx->ReadPackedVarint( - ptr, [object, is_valid, metadata, field_num](uint64_t val) { + ptr, [object, is_valid, metadata, field_num](int32_t val) { if (is_valid(val)) { static_cast*>(object)->Add(val); } else { @@ -1185,7 +1185,7 @@ PROTOBUF_NODISCARD const char* PackedEnumParserArg( bool (*is_valid)(const void*, int), const void* data, InternalMetadata* metadata, int field_num) { return ctx->ReadPackedVarint( - ptr, [object, is_valid, data, metadata, field_num](uint64_t val) { + ptr, [object, is_valid, data, metadata, field_num](int32_t val) { if (is_valid(data, val)) { static_cast*>(object)->Add(val); } else { diff --git a/src/google/protobuf/unittest.proto b/src/google/protobuf/unittest.proto index f6c57deb1b..108572e86f 100644 --- a/src/google/protobuf/unittest.proto +++ b/src/google/protobuf/unittest.proto @@ -1547,6 +1547,13 @@ message EnumParseTester { repeated Arbitrary packed_arbitrary_midfield = 1012 [packed = true]; repeated Arbitrary packed_arbitrary_hifield = 1000012 [packed = true]; + extensions 2000000 to max; + extend EnumParseTester { + optional Arbitrary optional_arbitrary_ext = 2000000; + repeated Arbitrary repeated_arbitrary_ext = 2000001; + repeated Arbitrary packed_arbitrary_ext = 2000002 [packed = true]; + } + // An arbitrary field we can append to to break the runs of repeated fields. optional int32 other_field = 99; } @@ -1564,6 +1571,13 @@ message BoolParseTester { repeated bool packed_bool_midfield = 1003 [packed = true]; repeated bool packed_bool_hifield = 1000003 [packed = true]; + extensions 2000000 to max; + extend BoolParseTester { + optional bool optional_bool_ext = 2000000; + repeated bool repeated_bool_ext = 2000001; + repeated bool packed_bool_ext = 2000002 [packed = true]; + } + // An arbitrary field we can append to to break the runs of repeated fields. optional int32 other_field = 99; } @@ -1579,6 +1593,13 @@ message Int32ParseTester { repeated int32 packed_int32_midfield = 1003 [packed = true]; repeated int32 packed_int32_hifield = 1000003 [packed = true]; + extensions 2000000 to max; + extend Int32ParseTester { + optional int32 optional_int32_ext = 2000000; + repeated int32 repeated_int32_ext = 2000001; + repeated int32 packed_int32_ext = 2000002 [packed = true]; + } + // An arbitrary field we can append to to break the runs of repeated fields. optional int32 other_field = 99; } @@ -1594,6 +1615,13 @@ message Int64ParseTester { repeated int64 packed_int64_midfield = 1003 [packed = true]; repeated int64 packed_int64_hifield = 1000003 [packed = true]; + extensions 2000000 to max; + extend Int64ParseTester { + optional int64 optional_int64_ext = 2000000; + repeated int64 repeated_int64_ext = 2000001; + repeated int64 packed_int64_ext = 2000002 [packed = true]; + } + // An arbitrary field we can append to to break the runs of repeated fields. optional int32 other_field = 99; } @@ -1617,6 +1645,12 @@ message StringParseTester { repeated string repeated_string_lowfield = 2; repeated string repeated_string_midfield = 1002; repeated string repeated_string_hifield = 1000002; + + extensions 2000000 to max; + extend StringParseTester { + optional string optional_string_ext = 2000000; + repeated string repeated_string_ext = 2000001; + } } message BadFieldNames{ diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc index e0c18bb00c..fe62d1c622 100644 --- a/src/google/protobuf/wire_format.cc +++ b/src/google/protobuf/wire_format.cc @@ -874,7 +874,7 @@ const char* WireFormat::_InternalParseAndMergeField( ptr = internal::PackedEnumParser(rep_enum, ptr, ctx); } else { return ctx->ReadPackedVarint( - ptr, [rep_enum, field, reflection, msg](uint64_t val) { + ptr, [rep_enum, field, reflection, msg](int32_t val) { if (field->enum_type()->FindValueByNumber(val) != nullptr) { rep_enum->Add(val); } else {