Change the C++ parsers to be consistent in how they handle unknown enum values.

They now consistently parse the varint and then push its int32 representation into the unknown field set. Previously the behavior varied between parsers: some pushed the original varint unchanged, some pushed the value truncated to uint32_t, and some pushed the value truncated to int32_t.

PiperOrigin-RevId: 512140469
pull/11958/head
Protobuf Team Bot 2 years ago committed by Copybara-Service
parent 8ee8d3522b
commit de4ce2e2de
  1. 22
      src/google/protobuf/compiler/cpp/parse_function_generator.cc
  2. 18
      src/google/protobuf/descriptor.pb.cc
  3. 8
      src/google/protobuf/extension_set_inl.h
  4. 3
      src/google/protobuf/generated_message_tctable_impl.h
  5. 44
      src/google/protobuf/generated_message_tctable_lite.cc
  6. 143
      src/google/protobuf/message_unittest.inc
  7. 4
      src/google/protobuf/parse_context.h
  8. 34
      src/google/protobuf/unittest.proto
  9. 2
      src/google/protobuf/wire_format.cc

@ -216,17 +216,6 @@ void ParseFunctionGenerator::GenerateTailcallParseFunction(Formatter& format) {
"}\n\n"); "}\n\n");
} }
// Returns true when `descriptor` has at least one field that is repeated, has
// C++ type enum, and for which
// internal::cpp::HasPreservingUnknownEnumSemantics() reports false.
// Returns false when no field satisfies all three conditions.
static bool NeedsUnknownEnumSupport(const Descriptor* descriptor) {
  for (int idx = 0; idx < descriptor->field_count(); ++idx) {
    auto* candidate = descriptor->field(idx);
    // The checks run in the same order as a short-circuiting `&&` chain.
    if (!candidate->is_repeated()) continue;
    if (candidate->cpp_type() != candidate->CPPTYPE_ENUM) continue;
    if (internal::cpp::HasPreservingUnknownEnumSemantics(candidate)) continue;
    return true;
  }
  return false;
}
void ParseFunctionGenerator::GenerateTailcallFallbackFunction( void ParseFunctionGenerator::GenerateTailcallFallbackFunction(
Formatter& format) { Formatter& format) {
ABSL_CHECK(should_generate_tctable()); ABSL_CHECK(should_generate_tctable());
@ -236,9 +225,8 @@ void ParseFunctionGenerator::GenerateTailcallFallbackFunction(
format.Indent(); format.Indent();
format("auto* typed_msg = static_cast<$classname$*>(msg);\n"); format("auto* typed_msg = static_cast<$classname$*>(msg);\n");
// If we need a side channel, generate the check to jump to the generic // Generate the check to jump to the generic handler to deal with the side
// handler to deal with the side channel data. // channel data.
if (NeedsUnknownEnumSupport(descriptor_)) {
format( format(
"if (PROTOBUF_PREDICT_FALSE(\n" "if (PROTOBUF_PREDICT_FALSE(\n"
" _pbi::TcParser::MustFallbackToGeneric(PROTOBUF_TC_PARAM_PASS))) " " _pbi::TcParser::MustFallbackToGeneric(PROTOBUF_TC_PARAM_PASS))) "
@ -246,11 +234,9 @@ void ParseFunctionGenerator::GenerateTailcallFallbackFunction(
" PROTOBUF_MUSTTAIL return " " PROTOBUF_MUSTTAIL return "
"::_pbi::TcParser::GenericFallback$1$(PROTOBUF_TC_PARAM_PASS);\n" "::_pbi::TcParser::GenericFallback$1$(PROTOBUF_TC_PARAM_PASS);\n"
"}\n", "}\n",
GetOptimizeFor(descriptor_->file(), options_) == GetOptimizeFor(descriptor_->file(), options_) == FileOptions::LITE_RUNTIME
FileOptions::LITE_RUNTIME
? "Lite" ? "Lite"
: ""); : "");
}
if (num_hasbits_ > 0) { if (num_hasbits_ > 0) {
// Sync hasbits // Sync hasbits
@ -1147,7 +1133,7 @@ void ParseFunctionGenerator::GenerateFieldBody(
format.Set("enum_type", format.Set("enum_type",
QualifiedClassName(field->enum_type(), options_)); QualifiedClassName(field->enum_type(), options_));
format( format(
"$uint32$ val = ::$proto_ns$::internal::ReadVarint32(&ptr);\n" "$int32$ val = ::$proto_ns$::internal::ReadVarint32(&ptr);\n"
"CHK_(ptr);\n"); "CHK_(ptr);\n");
if (!internal::cpp::HasPreservingUnknownEnumSemantics(field)) { if (!internal::cpp::HasPreservingUnknownEnumSemantics(field)) {
format( format(

@ -4252,7 +4252,7 @@ const char* FieldDescriptorProto::_InternalParse(const char* ptr, ::_pbi::ParseC
// optional .google.protobuf.FieldDescriptorProto.Label label = 4; // optional .google.protobuf.FieldDescriptorProto.Label label = 4;
case 4: case 4:
if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 32)) { if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 32)) {
::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
CHK_(ptr); CHK_(ptr);
if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label_IsValid(static_cast<int>(val)))) { if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label_IsValid(static_cast<int>(val)))) {
_internal_set_label(static_cast<::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label>(val)); _internal_set_label(static_cast<::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label>(val));
@ -4266,7 +4266,7 @@ const char* FieldDescriptorProto::_InternalParse(const char* ptr, ::_pbi::ParseC
// optional .google.protobuf.FieldDescriptorProto.Type type = 5; // optional .google.protobuf.FieldDescriptorProto.Type type = 5;
case 5: case 5:
if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 40)) { if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 40)) {
::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
CHK_(ptr); CHK_(ptr);
if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type_IsValid(static_cast<int>(val)))) { if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type_IsValid(static_cast<int>(val)))) {
_internal_set_type(static_cast<::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type>(val)); _internal_set_type(static_cast<::PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type>(val));
@ -7026,7 +7026,7 @@ const char* FileOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext* c
// optional .google.protobuf.FileOptions.OptimizeMode optimize_for = 9 [default = SPEED]; // optional .google.protobuf.FileOptions.OptimizeMode optimize_for = 9 [default = SPEED];
case 9: case 9:
if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 72)) { if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 72)) {
::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
CHK_(ptr); CHK_(ptr);
if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode_IsValid(static_cast<int>(val)))) { if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode_IsValid(static_cast<int>(val)))) {
_internal_set_optimize_for(static_cast<::PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode>(val)); _internal_set_optimize_for(static_cast<::PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode>(val));
@ -8297,7 +8297,7 @@ const char* FieldOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext*
// optional .google.protobuf.FieldOptions.CType ctype = 1 [default = STRING]; // optional .google.protobuf.FieldOptions.CType ctype = 1 [default = STRING];
case 1: case 1:
if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 8)) { if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 8)) {
::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
CHK_(ptr); CHK_(ptr);
if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_CType_IsValid(static_cast<int>(val)))) { if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_CType_IsValid(static_cast<int>(val)))) {
_internal_set_ctype(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_CType>(val)); _internal_set_ctype(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_CType>(val));
@ -8341,7 +8341,7 @@ const char* FieldOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext*
// optional .google.protobuf.FieldOptions.JSType jstype = 6 [default = JS_NORMAL]; // optional .google.protobuf.FieldOptions.JSType jstype = 6 [default = JS_NORMAL];
case 6: case 6:
if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 48)) { if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 48)) {
::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
CHK_(ptr); CHK_(ptr);
if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_JSType_IsValid(static_cast<int>(val)))) { if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_JSType_IsValid(static_cast<int>(val)))) {
_internal_set_jstype(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_JSType>(val)); _internal_set_jstype(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_JSType>(val));
@ -8385,7 +8385,7 @@ const char* FieldOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext*
// optional .google.protobuf.FieldOptions.OptionRetention retention = 17; // optional .google.protobuf.FieldOptions.OptionRetention retention = 17;
case 17: case 17:
if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 136)) { if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 136)) {
::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
CHK_(ptr); CHK_(ptr);
if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionRetention_IsValid(static_cast<int>(val)))) { if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionRetention_IsValid(static_cast<int>(val)))) {
_internal_set_retention(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionRetention>(val)); _internal_set_retention(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionRetention>(val));
@ -8399,7 +8399,7 @@ const char* FieldOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext*
// optional .google.protobuf.FieldOptions.OptionTargetType target = 18; // optional .google.protobuf.FieldOptions.OptionTargetType target = 18;
case 18: case 18:
if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 144)) { if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 144)) {
::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
CHK_(ptr); CHK_(ptr);
if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionTargetType_IsValid(static_cast<int>(val)))) { if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionTargetType_IsValid(static_cast<int>(val)))) {
_internal_set_target(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionTargetType>(val)); _internal_set_target(static_cast<::PROTOBUF_NAMESPACE_ID::FieldOptions_OptionTargetType>(val));
@ -9877,7 +9877,7 @@ const char* MethodOptions::_InternalParse(const char* ptr, ::_pbi::ParseContext*
// optional .google.protobuf.MethodOptions.IdempotencyLevel idempotency_level = 34 [default = IDEMPOTENCY_UNKNOWN]; // optional .google.protobuf.MethodOptions.IdempotencyLevel idempotency_level = 34 [default = IDEMPOTENCY_UNKNOWN];
case 34: case 34:
if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 16)) { if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 16)) {
::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
CHK_(ptr); CHK_(ptr);
if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel_IsValid(static_cast<int>(val)))) { if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel_IsValid(static_cast<int>(val)))) {
_internal_set_idempotency_level(static_cast<::PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel>(val)); _internal_set_idempotency_level(static_cast<::PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel>(val));
@ -11581,7 +11581,7 @@ const char* GeneratedCodeInfo_Annotation::_InternalParse(const char* ptr, ::_pbi
// optional .google.protobuf.GeneratedCodeInfo.Annotation.Semantic semantic = 5; // optional .google.protobuf.GeneratedCodeInfo.Annotation.Semantic semantic = 5;
case 5: case 5:
if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 40)) { if (PROTOBUF_PREDICT_TRUE(static_cast<::uint8_t>(tag) == 40)) {
::uint32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr); ::int32_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
CHK_(ptr); CHK_(ptr);
if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation_Semantic_IsValid(static_cast<int>(val)))) { if (PROTOBUF_PREDICT_TRUE(::PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation_Semantic_IsValid(static_cast<int>(val)))) {
_internal_set_semantic(static_cast<::PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation_Semantic>(val)); _internal_set_semantic(static_cast<::PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation_Semantic>(val));

@ -141,14 +141,14 @@ const char* ExtensionSet::ParseFieldWithExtensionInfo(
#undef HANDLE_FIXED_TYPE #undef HANDLE_FIXED_TYPE
case WireFormatLite::TYPE_ENUM: { case WireFormatLite::TYPE_ENUM: {
uint64_t val; uint64_t tmp;
ptr = VarintParse(ptr, &val); ptr = VarintParse(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
int value = val; int value = tmp;
if (!extension.enum_validity_check.func( if (!extension.enum_validity_check.func(
extension.enum_validity_check.arg, value)) { extension.enum_validity_check.arg, value)) {
WriteVarint(number, val, metadata->mutable_unknown_fields<T>()); WriteVarint(number, value, metadata->mutable_unknown_fields<T>());
} else if (extension.is_repeated) { } else if (extension.is_repeated) {
AddEnum(number, WireFormatLite::TYPE_ENUM, extension.is_packed, value, AddEnum(number, WireFormatLite::TYPE_ENUM, extension.is_packed, value,
extension.descriptor); extension.descriptor);

@ -655,6 +655,7 @@ class PROTOBUF_EXPORT TcParser final {
static const char* Error(PROTOBUF_TC_PARAM_NO_DATA_DECL); static const char* Error(PROTOBUF_TC_PARAM_NO_DATA_DECL);
static const char* FastUnknownEnumFallback(PROTOBUF_TC_PARAM_DECL); static const char* FastUnknownEnumFallback(PROTOBUF_TC_PARAM_DECL);
static const char* MpUnknownEnumFallback(PROTOBUF_TC_PARAM_DECL);
class ScopedArenaSwap; class ScopedArenaSwap;
@ -763,7 +764,7 @@ class PROTOBUF_EXPORT TcParser final {
const char* ptr, Arena* arena, SerialArena* serial_arena, const char* ptr, Arena* arena, SerialArena* serial_arena,
ParseContext* ctx, RepeatedPtrField<std::string>& field); ParseContext* ctx, RepeatedPtrField<std::string>& field);
static void UnknownPackedEnum(MessageLite* msg, const TcParseTableBase* table, static void AddUnknownEnum(MessageLite* msg, const TcParseTableBase* table,
uint32_t tag, int32_t enum_value); uint32_t tag, int32_t enum_value);
// Mini field lookup: // Mini field lookup:

@ -1224,16 +1224,34 @@ PROTOBUF_NOINLINE const char* TcParser::FastZ64P2(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_NOINLINE const char* TcParser::FastUnknownEnumFallback( PROTOBUF_NOINLINE const char* TcParser::FastUnknownEnumFallback(
PROTOBUF_TC_PARAM_DECL) { PROTOBUF_TC_PARAM_DECL) {
// If we know we want to put this field directly into the unknown field set, // Skip MiniParse/fallback and insert the element directly into the unknown
// then we can skip the call to MiniParse and directly call table->fallback. // field set. We also normalize the value into an int32 as we do for known
// However, we first have to update `data` to contain the decoded tag. // enum values.
uint32_t tag; uint32_t tag;
ptr = ReadTag(ptr, &tag); ptr = ReadTag(ptr, &tag);
if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) { if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) {
PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
} }
data.data = tag; uint64_t tmp;
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); ptr = ParseVarint(ptr, &tmp);
if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) {
PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
AddUnknownEnum(msg, table, tag, static_cast<int32_t>(tmp));
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
PROTOBUF_NOINLINE const char* TcParser::MpUnknownEnumFallback(
PROTOBUF_TC_PARAM_DECL) {
// Like FastUnknownEnumFallback, but with the Mp ABI.
uint32_t tag = data.tag();
uint64_t tmp;
ptr = ParseVarint(ptr, &tmp);
if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) {
PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
AddUnknownEnum(msg, table, tag, static_cast<int32_t>(tmp));
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS);
} }
template <typename TagType, uint16_t xform_val> template <typename TagType, uint16_t xform_val>
@ -1325,8 +1343,9 @@ const TcParser::UnknownFieldOps& TcParser::GetUnknownFieldOps(
return *reinterpret_cast<const UnknownFieldOps*>(ptr); return *reinterpret_cast<const UnknownFieldOps*>(ptr);
} }
PROTOBUF_NOINLINE void TcParser::UnknownPackedEnum( PROTOBUF_NOINLINE void TcParser::AddUnknownEnum(MessageLite* msg,
MessageLite* msg, const TcParseTableBase* table, uint32_t tag, const TcParseTableBase* table,
uint32_t tag,
int32_t enum_value) { int32_t enum_value) {
GetUnknownFieldOps(table).write_varint(msg, tag >> 3, enum_value); GetUnknownFieldOps(table).write_varint(msg, tag >> 3, enum_value);
} }
@ -1351,7 +1370,7 @@ const char* TcParser::PackedEnum(PROTOBUF_TC_PARAM_DECL) {
const TcParseTableBase::FieldAux aux = *table->field_aux(data.aux_idx()); const TcParseTableBase::FieldAux aux = *table->field_aux(data.aux_idx());
return ctx->ReadPackedVarint(ptr, [=](int32_t value) { return ctx->ReadPackedVarint(ptr, [=](int32_t value) {
if (!EnumIsValidAux(value, xform_val, aux)) { if (!EnumIsValidAux(value, xform_val, aux)) {
UnknownPackedEnum(msg, table, FastDecodeTag(saved_tag), value); AddUnknownEnum(msg, table, FastDecodeTag(saved_tag), value);
} else { } else {
field->Add(value); field->Add(value);
} }
@ -1498,7 +1517,7 @@ const char* TcParser::PackedEnumSmallRange(PROTOBUF_TC_PARAM_DECL) {
return ctx->ReadPackedVarint(ptr, [=](int32_t v) { return ctx->ReadPackedVarint(ptr, [=](int32_t v) {
if (PROTOBUF_PREDICT_FALSE(min > v || v > max)) { if (PROTOBUF_PREDICT_FALSE(min > v || v > max)) {
UnknownPackedEnum(msg, table, FastDecodeTag(saved_tag), v); AddUnknownEnum(msg, table, FastDecodeTag(saved_tag), v);
} else { } else {
field->Add(v); field->Add(v);
} }
@ -2021,7 +2040,7 @@ PROTOBUF_NOINLINE const char* TcParser::MpVarint(PROTOBUF_TC_PARAM_DECL) {
if (is_validated_enum) { if (is_validated_enum) {
if (!EnumIsValidAux(tmp, xform_val, *table->field_aux(&entry))) { if (!EnumIsValidAux(tmp, xform_val, *table->field_aux(&entry))) {
ptr = ptr2; ptr = ptr2;
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); PROTOBUF_MUSTTAIL return MpUnknownEnumFallback(PROTOBUF_TC_PARAM_PASS);
} }
} else if (is_zigzag) { } else if (is_zigzag) {
tmp = WireFormatLite::ZigZagDecode32(static_cast<uint32_t>(tmp)); tmp = WireFormatLite::ZigZagDecode32(static_cast<uint32_t>(tmp));
@ -2099,7 +2118,8 @@ PROTOBUF_NOINLINE const char* TcParser::MpRepeatedVarint(
if (is_validated_enum) { if (is_validated_enum) {
if (!EnumIsValidAux(tmp, xform_val, *table->field_aux(&entry))) { if (!EnumIsValidAux(tmp, xform_val, *table->field_aux(&entry))) {
ptr = ptr2; ptr = ptr2;
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS); PROTOBUF_MUSTTAIL return MpUnknownEnumFallback(
PROTOBUF_TC_PARAM_PASS);
} }
} else if (is_zigzag) { } else if (is_zigzag) {
tmp = WireFormatLite::ZigZagDecode32(tmp); tmp = WireFormatLite::ZigZagDecode32(tmp);
@ -2163,7 +2183,7 @@ PROTOBUF_NOINLINE const char* TcParser::MpPackedVarint(PROTOBUF_TC_PARAM_DECL) {
const TcParseTableBase::FieldAux aux = *table->field_aux(entry.aux_idx); const TcParseTableBase::FieldAux aux = *table->field_aux(entry.aux_idx);
return ctx->ReadPackedVarint(ptr, [=](int32_t value) { return ctx->ReadPackedVarint(ptr, [=](int32_t value) {
if (!EnumIsValidAux(value, xform_val, aux)) { if (!EnumIsValidAux(value, xform_val, aux)) {
UnknownPackedEnum(msg, table, data.tag(), value); AddUnknownEnum(msg, table, data.tag(), value);
} else { } else {
field->Add(value); field->Add(value);
} }

@ -61,6 +61,7 @@
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "google/protobuf/arena.h" #include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h" #include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/generated_message_reflection.h" #include "google/protobuf/generated_message_reflection.h"
#include "google/protobuf/generated_message_tctable_impl.h" #include "google/protobuf/generated_message_tctable_impl.h"
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
@ -1310,18 +1311,62 @@ std::string EncodeOverlongEnum(int number, bool use_packed) {
} }
} }
// Builds the wire bytes for int32 field `number` holding `value`.
//
// The tag and value are written into a local scratch buffer with
// WireFormatLite::WriteInt32ToArray, then the bytes after the tag are handed
// to AddNonCanonicalBytes together with `non_canonical_bytes`.
// NOTE(review): AddNonCanonicalBytes and SkipTag are defined elsewhere in this
// file; presumably they lengthen the varint with redundant continuation bytes
// and step over the tag, respectively -- confirm against their definitions.
std::string EncodeInt32Value(int number, int32_t value,
                             int non_canonical_bytes) {
  uint8_t scratch[100];
  // The tag must be written before SkipTag() inspects the buffer, so the two
  // steps stay as separate statements.
  uint8_t* tail =
      internal::WireFormatLite::WriteInt32ToArray(number, value, scratch);
  tail = AddNonCanonicalBytes(SkipTag(scratch), tail, non_canonical_bytes);
  return std::string(scratch, tail);
}
// Builds the wire bytes for int64 field `number` holding `value`.
//
// With `use_packed == false` the result is tag + varint, with the varint part
// passed through AddNonCanonicalBytes.
// With `use_packed == true` the varint is written without a tag, passed
// through AddNonCanonicalBytes, and the resulting bytes are then written as
// the payload of a string (length-delimited) field with the same number.
// NOTE(review): AddNonCanonicalBytes and SkipTag live elsewhere in this file;
// their exact effect is not visible here.
std::string EncodeInt64Value(int number, int64_t value, int non_canonical_bytes,
                             bool use_packed = false) {
  uint8_t scratch[100];

  if (!use_packed) {
    // Write the tag first; SkipTag() then reads it back from the buffer.
    uint8_t* tail =
        internal::WireFormatLite::WriteInt64ToArray(number, value, scratch);
    tail = AddNonCanonicalBytes(SkipTag(scratch), tail, non_canonical_bytes);
    return std::string(scratch, tail);
  }

  // Packed form: encode the bare element first...
  uint8_t* tail =
      internal::WireFormatLite::WriteInt64NoTagToArray(value, scratch);
  tail = AddNonCanonicalBytes(scratch, tail, non_canonical_bytes);
  // ...copy it out so the scratch buffer can be reused for the outer field...
  std::string payload(scratch, tail);
  // ...then wrap it as a length-delimited field.
  tail = internal::WireFormatLite::WriteStringToArray(number, payload, scratch);
  return std::string(scratch, tail);
}
std::string EncodeOtherField() { std::string EncodeOtherField() {
UNITTEST::EnumParseTester obj; UNITTEST::EnumParseTester obj;
obj.set_other_field(1); obj.set_other_field(1);
return obj.SerializeAsString(); return obj.SerializeAsString();
} }
template <typename T>
static std::vector<const FieldDescriptor*> GetFields() {
auto* descriptor = T::descriptor();
std::vector<const FieldDescriptor*> fields;
for (int i = 0; i < descriptor->field_count(); ++i) {
fields.push_back(descriptor->field(i));
}
for (int i = 0; i < descriptor->extension_count(); ++i) {
fields.push_back(descriptor->extension(i));
}
return fields;
}
TEST(MESSAGE_TEST_NAME, TestEnumParsers) { TEST(MESSAGE_TEST_NAME, TestEnumParsers) {
UNITTEST::EnumParseTester obj; UNITTEST::EnumParseTester obj;
const auto other_field = EncodeOtherField(); const auto other_field = EncodeOtherField();
// Encode a boolean field for many different cases and verify that it can be // Encode an enum field for many different cases and verify that it can be
// parsed as expected. // parsed as expected.
// There are: // There are:
// - optional/repeated/packed fields // - optional/repeated/packed fields
@ -1331,6 +1376,9 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) {
// - label combinations to trigger different parsers: sequential, small // - label combinations to trigger different parsers: sequential, small
// sequential, non-validated. // sequential, non-validated.
const std::vector<const FieldDescriptor*> fields =
GetFields<UNITTEST::EnumParseTester>();
constexpr int kInvalidValue = 0x900913; constexpr int kInvalidValue = 0x900913;
auto* ref = obj.GetReflection(); auto* ref = obj.GetReflection();
auto* descriptor = obj.descriptor(); auto* descriptor = obj.descriptor();
@ -1347,8 +1395,7 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) {
continue; continue;
} }
SCOPED_TRACE(add_garbage_bits); SCOPED_TRACE(add_garbage_bits);
for (int i = 0; i < descriptor->field_count(); ++i) { for (auto field : fields) {
const auto* field = descriptor->field(i);
if (field->name() == "other_field") continue; if (field->name() == "other_field") continue;
if (!field->is_repeated() && use_packed) continue; if (!field->is_repeated() && use_packed) continue;
SCOPED_TRACE(field->full_name()); SCOPED_TRACE(field->full_name());
@ -1421,6 +1468,52 @@ TEST(MESSAGE_TEST_NAME, TestEnumParsers) {
} }
} }
// Checks that an unknown enum value ends up in the unknown field set as its
// int32_t representation, for every enum field/extension of EnumParseTester,
// for both the generated message and a DynamicMessage built from the same
// descriptor, and for packed as well as non-packed encodings.
TEST(MESSAGE_TEST_NAME, TestEnumParserForUnknownEnumValue) {
  DynamicMessageFactory factory;
  std::unique_ptr<Message> dynamic(
      factory.GetPrototype(UNITTEST::EnumParseTester::descriptor())->New());
  UNITTEST::EnumParseTester non_dynamic;

  // For unknown enum values, for consistency we must include the
  // int32_t enum value in the unknown field set, which might not be exactly the
  // same as the input.

  const std::vector<const FieldDescriptor*> fields =
      GetFields<UNITTEST::EnumParseTester>();

  for (bool use_dynamic : {false, true}) {
    SCOPED_TRACE(use_dynamic);
    for (const auto* field : fields) {
      if (field->name() == "other_field") continue;
      SCOPED_TRACE(field->full_name());
      for (bool use_packed : {false, true}) {
        SCOPED_TRACE(use_packed);
        // Packed encoding only applies to repeated fields.
        if (!field->is_repeated() && use_packed) continue;

        // -2 is an invalid enum value on all the tests here.
        // We will encode -2 as a positive int64 that is equivalent to
        // int32_t{-2} when truncated.
        constexpr int64_t minus_2_non_canonical =
            static_cast<int64_t>(static_cast<uint32_t>(int32_t{-2}));
        static_assert(minus_2_non_canonical != -2, "");
        std::string encoded = EncodeInt64Value(
            field->number(), minus_2_non_canonical, 0, use_packed);

        // NOTE(review): `obj` is reused across iterations; the single-entry
        // assertion below relies on ParseFromString discarding the previous
        // contents -- confirm against the MessageLite documentation.
        auto& obj = use_dynamic ? *dynamic : non_dynamic;
        ASSERT_TRUE(obj.ParseFromString(encoded));
        auto& unknown = obj.GetReflection()->GetUnknownFields(obj);
        ASSERT_EQ(unknown.field_count(), 1);
        EXPECT_EQ(unknown.field(0).number(), field->number());
        EXPECT_EQ(unknown.field(0).type(), unknown.field(0).TYPE_VARINT);
        // The stored varint must be -2 widened to 64 bits, not the 32-bit
        // positive value that was on the wire. The conversion to unsigned is
        // spelled out so the comparison does not mix signedness.
        EXPECT_EQ(unknown.field(0).varint(),
                  static_cast<uint64_t>(int64_t{-2}));
      }
    }
  }
}
std::string EncodeBoolValue(int number, bool value, int non_canonical_bytes) { std::string EncodeBoolValue(int number, bool value, int non_canonical_bytes) {
uint8_t buf[100]; uint8_t buf[100];
uint8_t* p = buf; uint8_t* p = buf;
@ -1443,6 +1536,9 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) {
// - canonical and non-canonical encodings of the varint // - canonical and non-canonical encodings of the varint
// - last vs not last field // - last vs not last field
const std::vector<const FieldDescriptor*> fields =
GetFields<UNITTEST::BoolParseTester>();
auto* ref = obj.GetReflection(); auto* ref = obj.GetReflection();
auto* descriptor = obj.descriptor(); auto* descriptor = obj.descriptor();
for (bool use_tail_field : {false, true}) { for (bool use_tail_field : {false, true}) {
@ -1456,8 +1552,7 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) {
continue; continue;
} }
SCOPED_TRACE(add_garbage_bits); SCOPED_TRACE(add_garbage_bits);
for (int i = 0; i < descriptor->field_count(); ++i) { for (auto field : fields) {
const auto* field = descriptor->field(i);
if (field->name() == "other_field") continue; if (field->name() == "other_field") continue;
SCOPED_TRACE(field->full_name()); SCOPED_TRACE(field->full_name());
for (bool value : {false, true}) { for (bool value : {false, true}) {
@ -1492,16 +1587,6 @@ TEST(MESSAGE_TEST_NAME, TestBoolParsers) {
} }
} }
// Builds the wire bytes for int32 field `number` holding `value`: the tag and
// value are written with WireFormatLite::WriteInt32ToArray, and the bytes
// following the tag are then passed to AddNonCanonicalBytes together with
// `non_canonical_bytes`.
// NOTE(review): AddNonCanonicalBytes and SkipTag are defined elsewhere in this
// file; their exact effect is not visible here.
std::string EncodeInt32Value(int number, int32_t value,
                             int non_canonical_bytes) {
  uint8_t storage[100];
  // Tag + value go in first so that SkipTag() has a tag to step over.
  uint8_t* cursor =
      internal::WireFormatLite::WriteInt32ToArray(number, value, storage);
  cursor = AddNonCanonicalBytes(SkipTag(storage), cursor, non_canonical_bytes);
  std::string encoded(storage, cursor);
  return encoded;
}
TEST(MESSAGE_TEST_NAME, TestInt32Parsers) { TEST(MESSAGE_TEST_NAME, TestInt32Parsers) {
UNITTEST::Int32ParseTester obj; UNITTEST::Int32ParseTester obj;
@ -1515,6 +1600,9 @@ TEST(MESSAGE_TEST_NAME, TestInt32Parsers) {
// - canonical and non-canonical encodings of the varint // - canonical and non-canonical encodings of the varint
// - last vs not last field // - last vs not last field
const std::vector<const FieldDescriptor*> fields =
GetFields<UNITTEST::Int32ParseTester>();
auto* ref = obj.GetReflection(); auto* ref = obj.GetReflection();
auto* descriptor = obj.descriptor(); auto* descriptor = obj.descriptor();
for (bool use_tail_field : {false, true}) { for (bool use_tail_field : {false, true}) {
@ -1528,8 +1616,7 @@ TEST(MESSAGE_TEST_NAME, TestInt32Parsers) {
continue; continue;
} }
SCOPED_TRACE(add_garbage_bits); SCOPED_TRACE(add_garbage_bits);
for (int i = 0; i < descriptor->field_count(); ++i) { for (auto field : fields) {
const auto* field = descriptor->field(i);
if (field->name() == "other_field") continue; if (field->name() == "other_field") continue;
SCOPED_TRACE(field->full_name()); SCOPED_TRACE(field->full_name());
for (int32_t value : {1, 0, -1, (std::numeric_limits<int32_t>::min)(), for (int32_t value : {1, 0, -1, (std::numeric_limits<int32_t>::min)(),
@ -1565,16 +1652,6 @@ TEST(MESSAGE_TEST_NAME, TestInt32Parsers) {
} }
} }
// Builds the wire bytes for int64 field `number` holding `value`: the tag and
// value are written with WireFormatLite::WriteInt64ToArray, and the bytes
// following the tag are then passed to AddNonCanonicalBytes together with
// `non_canonical_bytes`.
// NOTE(review): AddNonCanonicalBytes and SkipTag are defined elsewhere in this
// file; their exact effect is not visible here.
std::string EncodeInt64Value(int number, int64_t value,
                             int non_canonical_bytes) {
  uint8_t storage[100];
  // Tag + value go in first so that SkipTag() has a tag to step over.
  uint8_t* cursor =
      internal::WireFormatLite::WriteInt64ToArray(number, value, storage);
  cursor = AddNonCanonicalBytes(SkipTag(storage), cursor, non_canonical_bytes);
  std::string encoded(storage, cursor);
  return encoded;
}
TEST(MESSAGE_TEST_NAME, TestInt64Parsers) { TEST(MESSAGE_TEST_NAME, TestInt64Parsers) {
UNITTEST::Int64ParseTester obj; UNITTEST::Int64ParseTester obj;
@ -1588,6 +1665,9 @@ TEST(MESSAGE_TEST_NAME, TestInt64Parsers) {
// - canonical and non-canonical encodings of the varint // - canonical and non-canonical encodings of the varint
// - last vs not last field // - last vs not last field
const std::vector<const FieldDescriptor*> fields =
GetFields<UNITTEST::Int64ParseTester>();
auto* ref = obj.GetReflection(); auto* ref = obj.GetReflection();
auto* descriptor = obj.descriptor(); auto* descriptor = obj.descriptor();
for (bool use_tail_field : {false, true}) { for (bool use_tail_field : {false, true}) {
@ -1601,8 +1681,7 @@ TEST(MESSAGE_TEST_NAME, TestInt64Parsers) {
continue; continue;
} }
SCOPED_TRACE(add_garbage_bits); SCOPED_TRACE(add_garbage_bits);
for (int i = 0; i < descriptor->field_count(); ++i) { for (auto field : fields) {
const auto* field = descriptor->field(i);
if (field->name() == "other_field") continue; if (field->name() == "other_field") continue;
SCOPED_TRACE(field->full_name()); SCOPED_TRACE(field->full_name());
for (int64_t value : {int64_t{1}, int64_t{0}, int64_t{-1}, for (int64_t value : {int64_t{1}, int64_t{0}, int64_t{-1},
@ -1748,6 +1827,9 @@ TEST(MESSAGE_TEST_NAME, TestRepeatedStringParsers) {
const auto* const descriptor = UNITTEST::StringParseTester::descriptor(); const auto* const descriptor = UNITTEST::StringParseTester::descriptor();
const std::vector<const FieldDescriptor*> fields =
GetFields<UNITTEST::StringParseTester>();
static const size_t sso_capacity = std::string().capacity(); static const size_t sso_capacity = std::string().capacity();
if (sso_capacity == 0) GTEST_SKIP(); if (sso_capacity == 0) GTEST_SKIP();
// SSO, !SSO, and off-by-one just in case // SSO, !SSO, and off-by-one just in case
@ -1755,8 +1837,7 @@ TEST(MESSAGE_TEST_NAME, TestRepeatedStringParsers) {
{sso_capacity - 1, sso_capacity, sso_capacity + 1, sso_capacity + 2}) { {sso_capacity - 1, sso_capacity, sso_capacity + 1, sso_capacity + 2}) {
SCOPED_TRACE(size); SCOPED_TRACE(size);
const std::string value = sample.substr(0, size); const std::string value = sample.substr(0, size);
for (int i = 0; i < descriptor->field_count(); ++i) { for (auto field : fields) {
const auto* field = descriptor->field(i);
SCOPED_TRACE(field->full_name()); SCOPED_TRACE(field->full_name());
const auto encoded = EncodeStringValue(field->number(), sample) + const auto encoded = EncodeStringValue(field->number(), sample) +
EncodeStringValue(field->number(), value); EncodeStringValue(field->number(), value);

@ -1170,7 +1170,7 @@ PROTOBUF_NODISCARD const char* PackedEnumParser(void* object, const char* ptr,
InternalMetadata* metadata, InternalMetadata* metadata,
int field_num) { int field_num) {
return ctx->ReadPackedVarint( return ctx->ReadPackedVarint(
ptr, [object, is_valid, metadata, field_num](uint64_t val) { ptr, [object, is_valid, metadata, field_num](int32_t val) {
if (is_valid(val)) { if (is_valid(val)) {
static_cast<RepeatedField<int>*>(object)->Add(val); static_cast<RepeatedField<int>*>(object)->Add(val);
} else { } else {
@ -1185,7 +1185,7 @@ PROTOBUF_NODISCARD const char* PackedEnumParserArg(
bool (*is_valid)(const void*, int), const void* data, bool (*is_valid)(const void*, int), const void* data,
InternalMetadata* metadata, int field_num) { InternalMetadata* metadata, int field_num) {
return ctx->ReadPackedVarint( return ctx->ReadPackedVarint(
ptr, [object, is_valid, data, metadata, field_num](uint64_t val) { ptr, [object, is_valid, data, metadata, field_num](int32_t val) {
if (is_valid(data, val)) { if (is_valid(data, val)) {
static_cast<RepeatedField<int>*>(object)->Add(val); static_cast<RepeatedField<int>*>(object)->Add(val);
} else { } else {

@ -1547,6 +1547,13 @@ message EnumParseTester {
repeated Arbitrary packed_arbitrary_midfield = 1012 [packed = true]; repeated Arbitrary packed_arbitrary_midfield = 1012 [packed = true];
repeated Arbitrary packed_arbitrary_hifield = 1000012 [packed = true]; repeated Arbitrary packed_arbitrary_hifield = 1000012 [packed = true];
extensions 2000000 to max;
// Extension fields on EnumParseTester, numbered 2000000-2000002: one optional,
// one repeated, and one packed repeated field of type Arbitrary.
extend EnumParseTester {
optional Arbitrary optional_arbitrary_ext = 2000000;
repeated Arbitrary repeated_arbitrary_ext = 2000001;
repeated Arbitrary packed_arbitrary_ext = 2000002 [packed = true];
}
// An arbitrary field we can append to to break the runs of repeated fields. // An arbitrary field we can append to to break the runs of repeated fields.
optional int32 other_field = 99; optional int32 other_field = 99;
} }
@ -1564,6 +1571,13 @@ message BoolParseTester {
repeated bool packed_bool_midfield = 1003 [packed = true]; repeated bool packed_bool_midfield = 1003 [packed = true];
repeated bool packed_bool_hifield = 1000003 [packed = true]; repeated bool packed_bool_hifield = 1000003 [packed = true];
extensions 2000000 to max;
// Extension fields on BoolParseTester, numbered 2000000-2000002: one optional,
// one repeated, and one packed repeated bool field.
extend BoolParseTester {
optional bool optional_bool_ext = 2000000;
repeated bool repeated_bool_ext = 2000001;
repeated bool packed_bool_ext = 2000002 [packed = true];
}
// An arbitrary field we can append to to break the runs of repeated fields. // An arbitrary field we can append to to break the runs of repeated fields.
optional int32 other_field = 99; optional int32 other_field = 99;
} }
@ -1579,6 +1593,13 @@ message Int32ParseTester {
repeated int32 packed_int32_midfield = 1003 [packed = true]; repeated int32 packed_int32_midfield = 1003 [packed = true];
repeated int32 packed_int32_hifield = 1000003 [packed = true]; repeated int32 packed_int32_hifield = 1000003 [packed = true];
extensions 2000000 to max;
// Extension fields on Int32ParseTester, numbered 2000000-2000002: one optional,
// one repeated, and one packed repeated int32 field.
extend Int32ParseTester {
optional int32 optional_int32_ext = 2000000;
repeated int32 repeated_int32_ext = 2000001;
repeated int32 packed_int32_ext = 2000002 [packed = true];
}
// An arbitrary field we can append to to break the runs of repeated fields. // An arbitrary field we can append to to break the runs of repeated fields.
optional int32 other_field = 99; optional int32 other_field = 99;
} }
@ -1594,6 +1615,13 @@ message Int64ParseTester {
repeated int64 packed_int64_midfield = 1003 [packed = true]; repeated int64 packed_int64_midfield = 1003 [packed = true];
repeated int64 packed_int64_hifield = 1000003 [packed = true]; repeated int64 packed_int64_hifield = 1000003 [packed = true];
extensions 2000000 to max;
// Extension fields on Int64ParseTester, numbered 2000000-2000002: one optional,
// one repeated, and one packed repeated int64 field.
extend Int64ParseTester {
optional int64 optional_int64_ext = 2000000;
repeated int64 repeated_int64_ext = 2000001;
repeated int64 packed_int64_ext = 2000002 [packed = true];
}
// An arbitrary field we can append to to break the runs of repeated fields. // An arbitrary field we can append to to break the runs of repeated fields.
optional int32 other_field = 99; optional int32 other_field = 99;
} }
@ -1617,6 +1645,12 @@ message StringParseTester {
repeated string repeated_string_lowfield = 2; repeated string repeated_string_lowfield = 2;
repeated string repeated_string_midfield = 1002; repeated string repeated_string_midfield = 1002;
repeated string repeated_string_hifield = 1000002; repeated string repeated_string_hifield = 1000002;
extensions 2000000 to max;
// Extension fields on StringParseTester, numbered 2000000 and 2000001: one
// optional and one repeated string field (no packed variant is declared here).
extend StringParseTester {
optional string optional_string_ext = 2000000;
repeated string repeated_string_ext = 2000001;
}
} }
message BadFieldNames{ message BadFieldNames{

@ -874,7 +874,7 @@ const char* WireFormat::_InternalParseAndMergeField(
ptr = internal::PackedEnumParser(rep_enum, ptr, ctx); ptr = internal::PackedEnumParser(rep_enum, ptr, ctx);
} else { } else {
return ctx->ReadPackedVarint( return ctx->ReadPackedVarint(
ptr, [rep_enum, field, reflection, msg](uint64_t val) { ptr, [rep_enum, field, reflection, msg](int32_t val) {
if (field->enum_type()->FindValueByNumber(val) != nullptr) { if (field->enum_type()->FindValueByNumber(val) != nullptr) {
rep_enum->Add(val); rep_enum->Add(val);
} else { } else {

Loading…
Cancel
Save