Unify MessageSchema.parseMessage proto2 and proto3 codepaths

This handles the following proto2/proto3 differences in a single parseMessage codepath that works for proto2, proto3, and editions:
- Groups (proto2)
- Open (proto3) vs closed (proto2) enums, incl closed enums in unknown fields
- Extensions (proto2)
- No presence (proto3)

PiperOrigin-RevId: 542872685
pull/13133/head
Sandy Zhang 1 year ago committed by Copybara-Service
parent 8113bdef84
commit e5936049ae
  1. 3
      java/core/src/main/java/com/google/protobuf/ArrayDecoders.java
  2. 311
      java/core/src/main/java/com/google/protobuf/MessageSchema.java
  3. 1
      src/google/protobuf/compiler/java/helpers.h

@ -286,9 +286,8 @@ final class ArrayDecoders {
// A group field must have a MessageSchema (the only other subclass of Schema is MessageSetSchema
// and it can't be used in group fields).
final MessageSchema messageSchema = (MessageSchema) schema;
// It's OK to directly use parseProto2Message since proto3 doesn't have group.
final int endPosition =
messageSchema.parseProto2Message(msg, data, position, limit, endGroup, registers);
messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
registers.object1 = msg;
return endPosition;
}

@ -91,6 +91,7 @@ final class MessageSchema<T> implements Schema<T> {
private static final int FIELD_TYPE_MASK = 0x0FF00000;
private static final int REQUIRED_MASK = 0x10000000;
private static final int ENFORCE_UTF8_MASK = 0x20000000;
private static final int LEGACY_ENUM_IS_CLOSED_MASK = 0x80000000;
private static final int NO_PRESENCE_SENTINEL = -1 & OFFSET_MASK;
private static final int[] EMPTY_INT_ARRAY = new int[0];
@ -593,6 +594,9 @@ final class MessageSchema<T> implements Schema<T> {
buffer[bufferIndex++] =
((fieldTypeWithExtraBits & UTF8_CHECK_BIT) != 0 ? ENFORCE_UTF8_MASK : 0)
| ((fieldTypeWithExtraBits & REQUIRED_BIT) != 0 ? REQUIRED_MASK : 0)
| ((fieldTypeWithExtraBits & LEGACY_ENUM_IS_CLOSED_BIT) != 0
? LEGACY_ENUM_IS_CLOSED_MASK
: 0)
| (fieldType << OFFSET_BITS)
| fieldOffset;
buffer[bufferIndex++] = (presenceMaskShift << OFFSET_BITS) | presenceFieldOffset;
@ -3942,14 +3946,14 @@ final class MessageSchema<T> implements Schema<T> {
}
/**
* Parses a proto2 message or group and returns the position after the message/group. If it's
* parsing a message (endGroup == 0), returns limit if parsing is successful; It it's parsing a
* group (endGroup != 0), parsing ends when a tag == endGroup is encountered and the position
* after that tag is returned.
* Parses a message and returns the position after the message/group. If it's parsing a
* LENGTH_PREFIXED message (endDelimited == 0), returns limit if parsing is successful; If it's
* parsing a DELIMITED message aka group (endDelimited != 0), parsing ends when a tag ==
* endDelimited is encountered and the position after that tag is returned.
*/
@CanIgnoreReturnValue
int parseProto2Message(
T message, byte[] data, int position, int limit, int endGroup, Registers registers)
int parseMessage(
T message, byte[] data, int position, int limit, int endDelimited, Registers registers)
throws IOException {
checkMutable(message);
final sun.misc.Unsafe unsafe = UNSAFE;
@ -3980,18 +3984,23 @@ final class MessageSchema<T> implements Schema<T> {
final int fieldType = type(typeAndOffset);
final long fieldOffset = offset(typeAndOffset);
if (fieldType <= 17) {
// Proto2 optional fields have has-bits.
// Fields with explicit presence (i.e. optional) have has-bits.
final int presenceMaskAndOffset = buffer[pos + 2];
final int presenceMask = 1 << (presenceMaskAndOffset >>> OFFSET_BITS);
final int presenceFieldOffset = presenceMaskAndOffset & OFFSET_MASK;
// We cache the 32-bit has-bits integer value and only write it back when parsing a field
// using a different has-bits integer.
// We cache the 32-bit presence integer value and only write it back when parsing a field
// using a different presence integer.
if (presenceFieldOffset != currentPresenceFieldOffset) {
if (currentPresenceFieldOffset != NO_PRESENCE_SENTINEL) {
unsafe.putInt(message, (long) currentPresenceFieldOffset, currentPresenceField);
}
currentPresenceFieldOffset = presenceFieldOffset;
currentPresenceField = unsafe.getInt(message, (long) presenceFieldOffset);
// For fields without presence, we unconditionally write and discard
// the data.
currentPresenceField =
presenceFieldOffset == NO_PRESENCE_SENTINEL
? 0
: unsafe.getInt(message, (long) presenceFieldOffset);
}
switch (fieldType) {
case 0: // DOUBLE
@ -4056,10 +4065,10 @@ final class MessageSchema<T> implements Schema<T> {
break;
case 8: // STRING
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
if ((typeAndOffset & ENFORCE_UTF8_MASK) == 0) {
position = decodeString(data, position, registers);
} else {
if (isEnforceUtf8(typeAndOffset)) {
position = decodeStringRequireUtf8(data, position, registers);
} else {
position = decodeString(data, position, registers);
}
unsafe.putObject(message, fieldOffset, registers.object1);
currentPresenceField |= presenceMask;
@ -4090,10 +4099,14 @@ final class MessageSchema<T> implements Schema<T> {
position = decodeVarint32(data, position, registers);
final int enumValue = registers.int1;
EnumVerifier enumVerifier = getEnumFieldVerifier(pos);
if (enumVerifier == null || enumVerifier.isInRange(enumValue)) {
if (!isLegacyEnumIsClosed(typeAndOffset)
|| enumVerifier == null
|| enumVerifier.isInRange(enumValue)) {
// Parse open enums and in-range closed enums into their fields directly.
unsafe.putInt(message, fieldOffset, enumValue);
currentPresenceField |= presenceMask;
} else {
// Store out-of-range closed enums in unknown fields.
// UnknownFieldSetLite requires varint to be represented as Long.
getMutableUnknownFields(message).storeField(tag, (long) enumValue);
}
@ -4141,7 +4154,7 @@ final class MessageSchema<T> implements Schema<T> {
break;
}
} else if (fieldType == 27) {
// Handle repeated message fields.
// Handle repeated message field.
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
ProtobufList<?> list = (ProtobufList<?>) unsafe.getObject(message, fieldOffset);
if (!list.isModifiable()) {
@ -4205,7 +4218,7 @@ final class MessageSchema<T> implements Schema<T> {
}
}
}
if (tag == endGroup && endGroup != 0) {
if (tag == endDelimited && endDelimited != 0) {
break;
}
@ -4237,12 +4250,12 @@ final class MessageSchema<T> implements Schema<T> {
((UnknownFieldSchema<UnknownFieldSetLite, UnknownFieldSetLite>) unknownFieldSchema)
.setBuilderToMessage(message, unknownFields);
}
if (endGroup == 0) {
if (endDelimited == 0) {
if (position != limit) {
throw InvalidProtocolBufferException.parseFailure();
}
} else {
if (position > limit || tag != endGroup) {
if (position > limit || tag != endDelimited) {
throw InvalidProtocolBufferException.parseFailure();
}
}
@ -4304,266 +4317,10 @@ final class MessageSchema<T> implements Schema<T> {
setOneofPresent(message, fieldNumber, pos);
}
/**
 * Parses a proto3 message and returns the limit if parsing is successful.
 *
 * <p>Reads tag/value pairs from {@code data} starting at {@code position} until {@code limit}
 * is reached. Fields found in the schema table are written directly into {@code message};
 * any field that is not in the table, or whose wire type does not match, is passed to
 * {@code decodeUnknownField}.
 *
 * @param message the message to merge into; {@code checkMutable} is called on it first
 * @param data the serialized bytes
 * @param position index of the first byte to parse
 * @param limit index at which parsing must end exactly
 * @param registers scratch holder through which the decode helpers hand back values
 * @return the final position, which equals {@code limit} on success
 * @throws IOException an {@code InvalidProtocolBufferException} parse failure if parsing does
 *     not end exactly at {@code limit}; presumably also whatever the decode helpers throw on
 *     malformed input
 */
@CanIgnoreReturnValue
private int parseProto3Message(
T message, byte[] data, int position, int limit, Registers registers) throws IOException {
checkMutable(message);
final sun.misc.Unsafe unsafe = UNSAFE;
// Offset and cached value of the presence word currently being accumulated; the sentinel
// means no presence word is cached.
int currentPresenceFieldOffset = NO_PRESENCE_SENTINEL;
int currentPresenceField = 0;
int tag = 0;
// Previous field number and schema-table position, used to choose the lookup call below.
int oldNumber = -1;
int pos = 0;
while (position < limit) {
// Fast path for a one-byte tag. A negative byte has its high bit set, so the tag continues
// into further bytes and is decoded as a full varint32.
tag = data[position++];
if (tag < 0) {
position = decodeVarint32(tag, data, position, registers);
tag = registers.int1;
}
// The low three bits of a tag are the wire type; the remaining bits are the field number.
final int number = tag >>> 3;
final int wireType = tag & 0x7;
// When field numbers increase, the lookup is given the previous table index (presumably as a
// starting hint); otherwise it searches without one.
if (number > oldNumber) {
pos = positionForFieldNumber(number, pos / INTS_PER_FIELD);
} else {
pos = positionForFieldNumber(number);
}
oldNumber = number;
if (pos == -1) {
// need to reset
// No schema entry for this field number: control falls through to decodeUnknownField below.
pos = 0;
} else {
final int typeAndOffset = buffer[pos + 1];
final int fieldType = type(typeAndOffset);
final long fieldOffset = offset(typeAndOffset);
// Field types 0..17 are handled inline by the switch below.
if (fieldType <= 17) {
// Proto3 optional fields have has-bits.
final int presenceMaskAndOffset = buffer[pos + 2];
final int presenceMask = 1 << (presenceMaskAndOffset >>> OFFSET_BITS);
final int presenceFieldOffset = presenceMaskAndOffset & OFFSET_MASK;
// We cache the 32-bit has-bits integer value and only write it back when parsing a field
// using a different has-bits integer.
//
// Note that for fields that do not have hasbits, we unconditionally write and discard
// the data.
if (presenceFieldOffset != currentPresenceFieldOffset) {
if (currentPresenceFieldOffset != NO_PRESENCE_SENTINEL) {
unsafe.putInt(message, (long) currentPresenceFieldOffset, currentPresenceField);
}
if (presenceFieldOffset != NO_PRESENCE_SENTINEL) {
currentPresenceField = unsafe.getInt(message, (long) presenceFieldOffset);
}
currentPresenceFieldOffset = presenceFieldOffset;
}
// Each case stores the value and `continue`s when the wire type matches. A `break` means
// the wire type did not match, and the field falls through to decodeUnknownField.
switch (fieldType) {
case 0: // DOUBLE:
if (wireType == WireFormat.WIRETYPE_FIXED64) {
UnsafeUtil.putDouble(message, fieldOffset, decodeDouble(data, position));
position += 8;
currentPresenceField |= presenceMask;
continue;
}
break;
case 1: // FLOAT:
if (wireType == WireFormat.WIRETYPE_FIXED32) {
UnsafeUtil.putFloat(message, fieldOffset, decodeFloat(data, position));
position += 4;
currentPresenceField |= presenceMask;
continue;
}
break;
case 2: // INT64:
case 3: // UINT64:
if (wireType == WireFormat.WIRETYPE_VARINT) {
position = decodeVarint64(data, position, registers);
unsafe.putLong(message, fieldOffset, registers.long1);
currentPresenceField |= presenceMask;
continue;
}
break;
case 4: // INT32:
case 11: // UINT32:
if (wireType == WireFormat.WIRETYPE_VARINT) {
position = decodeVarint32(data, position, registers);
unsafe.putInt(message, fieldOffset, registers.int1);
currentPresenceField |= presenceMask;
continue;
}
break;
case 5: // FIXED64:
case 14: // SFIXED64:
if (wireType == WireFormat.WIRETYPE_FIXED64) {
unsafe.putLong(message, fieldOffset, decodeFixed64(data, position));
position += 8;
currentPresenceField |= presenceMask;
continue;
}
break;
case 6: // FIXED32:
case 13: // SFIXED32:
if (wireType == WireFormat.WIRETYPE_FIXED32) {
unsafe.putInt(message, fieldOffset, decodeFixed32(data, position));
position += 4;
currentPresenceField |= presenceMask;
continue;
}
break;
case 7: // BOOL:
if (wireType == WireFormat.WIRETYPE_VARINT) {
// The bool is read as a 64-bit varint; any nonzero value is stored as true.
position = decodeVarint64(data, position, registers);
UnsafeUtil.putBoolean(message, fieldOffset, registers.long1 != 0);
currentPresenceField |= presenceMask;
continue;
}
break;
case 8: // STRING:
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
// The ENFORCE_UTF8_MASK bit of the field entry selects the UTF-8-requiring decoder.
if ((typeAndOffset & ENFORCE_UTF8_MASK) == 0) {
position = decodeString(data, position, registers);
} else {
position = decodeStringRequireUtf8(data, position, registers);
}
unsafe.putObject(message, fieldOffset, registers.object1);
currentPresenceField |= presenceMask;
continue;
}
break;
case 9: // MESSAGE:
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
// Merge into the field's mutable message value, then store it back into the field.
final Object current = mutableMessageFieldForMerge(message, pos);
position =
mergeMessageField(
current, getMessageFieldSchema(pos), data, position, limit, registers);
storeMessageField(message, pos, current);
currentPresenceField |= presenceMask;
continue;
}
break;
case 10: // BYTES:
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
position = decodeBytes(data, position, registers);
unsafe.putObject(message, fieldOffset, registers.object1);
currentPresenceField |= presenceMask;
continue;
}
break;
case 12: // ENUM:
if (wireType == WireFormat.WIRETYPE_VARINT) {
// The decoded enum number is stored as-is; no range check is applied in this method.
position = decodeVarint32(data, position, registers);
unsafe.putInt(message, fieldOffset, registers.int1);
currentPresenceField |= presenceMask;
continue;
}
break;
case 15: // SINT32:
if (wireType == WireFormat.WIRETYPE_VARINT) {
position = decodeVarint32(data, position, registers);
unsafe.putInt(
message, fieldOffset, CodedInputStream.decodeZigZag32(registers.int1));
currentPresenceField |= presenceMask;
continue;
}
break;
case 16: // SINT64:
if (wireType == WireFormat.WIRETYPE_VARINT) {
position = decodeVarint64(data, position, registers);
unsafe.putLong(
message, fieldOffset, CodedInputStream.decodeZigZag64(registers.long1));
currentPresenceField |= presenceMask;
continue;
}
break;
default:
break;
}
} else if (fieldType == 27) {
// Handle repeated message field.
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
ProtobufList<?> list = (ProtobufList<?>) unsafe.getObject(message, fieldOffset);
// A list that is not modifiable is replaced by a mutable copy (default capacity when empty,
// otherwise twice the current size) before elements are appended.
if (!list.isModifiable()) {
final int size = list.size();
list =
list.mutableCopyWithCapacity(
size == 0 ? AbstractProtobufList.DEFAULT_CAPACITY : size * 2);
unsafe.putObject(message, fieldOffset, list);
}
position =
decodeMessageList(
getMessageFieldSchema(pos), tag, data, position, limit, list, registers);
continue;
}
} else if (fieldType <= 49) {
// Handle all other repeated fields.
// If the helper leaves position unchanged, the field falls through to decodeUnknownField.
final int oldPosition = position;
position =
parseRepeatedField(
message,
data,
position,
limit,
tag,
number,
wireType,
pos,
typeAndOffset,
fieldType,
fieldOffset,
registers);
if (position != oldPosition) {
continue;
}
} else if (fieldType == 50) {
// Field type 50 is delegated to parseMapField; same unchanged-position fallback as above.
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
final int oldPosition = position;
position = parseMapField(message, data, position, limit, pos, fieldOffset, registers);
if (position != oldPosition) {
continue;
}
}
} else {
// Every remaining field type is delegated to parseOneofField; same fallback as above.
final int oldPosition = position;
position =
parseOneofField(
message,
data,
position,
limit,
tag,
number,
wireType,
typeAndOffset,
fieldType,
fieldOffset,
pos,
registers);
if (position != oldPosition) {
continue;
}
}
}
// Reached for fields with no schema entry, a mismatched wire type, or a helper that did not
// advance position: the field is decoded into the message's unknown fields.
position = decodeUnknownField(
tag, data, position, limit, getMutableUnknownFields(message), registers);
}
// Write back the presence word that is still cached, if any.
if (currentPresenceFieldOffset != NO_PRESENCE_SENTINEL) {
unsafe.putInt(message, (long) currentPresenceFieldOffset, currentPresenceField);
}
// Parsing must end exactly at limit; anything else is reported as a parse failure.
if (position != limit) {
throw InvalidProtocolBufferException.parseFailure();
}
return position;
}
/**
 * Merges the serialized bytes in {@code data} between {@code position} and {@code limit} into
 * {@code message}.
 *
 * <p>All syntaxes go through the single {@code parseMessage} codepath, so the input is parsed
 * exactly once. Dispatching on {@code syntax} first and then calling {@code parseMessage} as
 * well would merge the same bytes into the message twice.
 *
 * @param message the message to merge into
 * @param data the serialized bytes
 * @param position index of the first byte to parse
 * @param limit index at which parsing must end
 * @param registers scratch holder used by the decode helpers
 * @throws IOException if {@code parseMessage} fails to parse the input
 */
@Override
public void mergeFrom(T message, byte[] data, int position, int limit, Registers registers)
    throws IOException {
  // endDelimited == 0 selects length-prefixed parsing: there is no end-group tag to stop at.
  parseMessage(message, data, position, limit, 0, registers);
}
@Override
@ -4935,6 +4692,10 @@ final class MessageSchema<T> implements Schema<T> {
return (value & ENFORCE_UTF8_MASK) != 0;
}
/**
 * Returns whether the {@code LEGACY_ENUM_IS_CLOSED_MASK} bit is set in {@code value}.
 *
 * @param value a packed field entry, such as the {@code typeAndOffset} word
 */
private static boolean isLegacyEnumIsClosed(int value) {
  // Compare against zero rather than testing for a positive value, so the check holds whatever
  // bit the mask occupies, including the sign bit.
  final int closedBit = value & LEGACY_ENUM_IS_CLOSED_MASK;
  return closedBit != 0;
}
/**
 * Extracts the bits selected by {@code OFFSET_MASK} from {@code value}, widened to {@code long}.
 *
 * @param value a packed field entry, such as the {@code typeAndOffset} word
 */
private static long offset(int value) {
  final int maskedOffset = value & OFFSET_MASK;
  return maskedOffset;
}

@ -372,6 +372,7 @@ inline bool ExposePublicParser(const FileDescriptor* descriptor) {
// but in the message and can be queried using additional getters that return
// ints.
inline bool SupportUnknownEnumValue(const FieldDescriptor* field) {
  // TODO(b/279034699): Check Java legacy_enum_field_treated_as_closed feature.
  // Unknown enum values are supported exactly when the field is not treated as closed.
  const bool treated_as_closed = field->legacy_enum_field_treated_as_closed();
  return !treated_as_closed;
}

Loading…
Cancel
Save