Unify MessageSchema.parseMessage proto2 and proto3 codepaths

This handles the following proto2/proto3 differences in a single parseMessage codepath that works for proto2, proto3, and editions:
- Groups (proto2)
- Open (proto3) vs. closed (proto2) enums, including storing out-of-range closed enum values in unknown fields
- Extensions (proto2)
- No presence (proto3)

PiperOrigin-RevId: 542872685
pull/13133/head
Sandy Zhang 1 year ago committed by Copybara-Service
parent 8113bdef84
commit e5936049ae
  1. 3
      java/core/src/main/java/com/google/protobuf/ArrayDecoders.java
  2. 311
      java/core/src/main/java/com/google/protobuf/MessageSchema.java
  3. 1
      src/google/protobuf/compiler/java/helpers.h

@ -286,9 +286,8 @@ final class ArrayDecoders {
// A group field must have a MessageSchema (the only other subclass of Schema is MessageSetSchema // A group field must have a MessageSchema (the only other subclass of Schema is MessageSetSchema
// and it can't be used in group fields). // and it can't be used in group fields).
final MessageSchema messageSchema = (MessageSchema) schema; final MessageSchema messageSchema = (MessageSchema) schema;
// It's OK to directly use parseProto2Message since proto3 doesn't have group.
final int endPosition = final int endPosition =
messageSchema.parseProto2Message(msg, data, position, limit, endGroup, registers); messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
registers.object1 = msg; registers.object1 = msg;
return endPosition; return endPosition;
} }

@ -91,6 +91,7 @@ final class MessageSchema<T> implements Schema<T> {
private static final int FIELD_TYPE_MASK = 0x0FF00000; private static final int FIELD_TYPE_MASK = 0x0FF00000;
private static final int REQUIRED_MASK = 0x10000000; private static final int REQUIRED_MASK = 0x10000000;
private static final int ENFORCE_UTF8_MASK = 0x20000000; private static final int ENFORCE_UTF8_MASK = 0x20000000;
private static final int LEGACY_ENUM_IS_CLOSED_MASK = 0x80000000;
private static final int NO_PRESENCE_SENTINEL = -1 & OFFSET_MASK; private static final int NO_PRESENCE_SENTINEL = -1 & OFFSET_MASK;
private static final int[] EMPTY_INT_ARRAY = new int[0]; private static final int[] EMPTY_INT_ARRAY = new int[0];
@ -593,6 +594,9 @@ final class MessageSchema<T> implements Schema<T> {
buffer[bufferIndex++] = buffer[bufferIndex++] =
((fieldTypeWithExtraBits & UTF8_CHECK_BIT) != 0 ? ENFORCE_UTF8_MASK : 0) ((fieldTypeWithExtraBits & UTF8_CHECK_BIT) != 0 ? ENFORCE_UTF8_MASK : 0)
| ((fieldTypeWithExtraBits & REQUIRED_BIT) != 0 ? REQUIRED_MASK : 0) | ((fieldTypeWithExtraBits & REQUIRED_BIT) != 0 ? REQUIRED_MASK : 0)
| ((fieldTypeWithExtraBits & LEGACY_ENUM_IS_CLOSED_BIT) != 0
? LEGACY_ENUM_IS_CLOSED_MASK
: 0)
| (fieldType << OFFSET_BITS) | (fieldType << OFFSET_BITS)
| fieldOffset; | fieldOffset;
buffer[bufferIndex++] = (presenceMaskShift << OFFSET_BITS) | presenceFieldOffset; buffer[bufferIndex++] = (presenceMaskShift << OFFSET_BITS) | presenceFieldOffset;
@ -3942,14 +3946,14 @@ final class MessageSchema<T> implements Schema<T> {
} }
/** /**
* Parses a proto2 message or group and returns the position after the message/group. If it's * Parses a message and returns the position after the message/group. If it's parsing a
* parsing a message (endGroup == 0), returns limit if parsing is successful; It it's parsing a * LENGTH_PREFIXED message (endDelimited == 0), returns limit if parsing is successful; If it's
* group (endGroup != 0), parsing ends when a tag == endGroup is encountered and the position * parsing a DELIMITED message aka group (endDelimited != 0), parsing ends when a tag ==
* after that tag is returned. * endDelimited is encountered and the position after that tag is returned.
*/ */
@CanIgnoreReturnValue @CanIgnoreReturnValue
int parseProto2Message( int parseMessage(
T message, byte[] data, int position, int limit, int endGroup, Registers registers) T message, byte[] data, int position, int limit, int endDelimited, Registers registers)
throws IOException { throws IOException {
checkMutable(message); checkMutable(message);
final sun.misc.Unsafe unsafe = UNSAFE; final sun.misc.Unsafe unsafe = UNSAFE;
@ -3980,18 +3984,23 @@ final class MessageSchema<T> implements Schema<T> {
final int fieldType = type(typeAndOffset); final int fieldType = type(typeAndOffset);
final long fieldOffset = offset(typeAndOffset); final long fieldOffset = offset(typeAndOffset);
if (fieldType <= 17) { if (fieldType <= 17) {
// Proto2 optional fields have has-bits. // Fields with explicit presence (i.e. optional) have has-bits.
final int presenceMaskAndOffset = buffer[pos + 2]; final int presenceMaskAndOffset = buffer[pos + 2];
final int presenceMask = 1 << (presenceMaskAndOffset >>> OFFSET_BITS); final int presenceMask = 1 << (presenceMaskAndOffset >>> OFFSET_BITS);
final int presenceFieldOffset = presenceMaskAndOffset & OFFSET_MASK; final int presenceFieldOffset = presenceMaskAndOffset & OFFSET_MASK;
// We cache the 32-bit has-bits integer value and only write it back when parsing a field // We cache the 32-bit presence integer value and only write it back when parsing a field
// using a different has-bits integer. // using a different presence integer.
if (presenceFieldOffset != currentPresenceFieldOffset) { if (presenceFieldOffset != currentPresenceFieldOffset) {
if (currentPresenceFieldOffset != NO_PRESENCE_SENTINEL) { if (currentPresenceFieldOffset != NO_PRESENCE_SENTINEL) {
unsafe.putInt(message, (long) currentPresenceFieldOffset, currentPresenceField); unsafe.putInt(message, (long) currentPresenceFieldOffset, currentPresenceField);
} }
currentPresenceFieldOffset = presenceFieldOffset; currentPresenceFieldOffset = presenceFieldOffset;
currentPresenceField = unsafe.getInt(message, (long) presenceFieldOffset); // For fields without presence, we unconditionally write and discard
// the data.
currentPresenceField =
presenceFieldOffset == NO_PRESENCE_SENTINEL
? 0
: unsafe.getInt(message, (long) presenceFieldOffset);
} }
switch (fieldType) { switch (fieldType) {
case 0: // DOUBLE case 0: // DOUBLE
@ -4056,10 +4065,10 @@ final class MessageSchema<T> implements Schema<T> {
break; break;
case 8: // STRING case 8: // STRING
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) { if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
if ((typeAndOffset & ENFORCE_UTF8_MASK) == 0) { if (isEnforceUtf8(typeAndOffset)) {
position = decodeString(data, position, registers);
} else {
position = decodeStringRequireUtf8(data, position, registers); position = decodeStringRequireUtf8(data, position, registers);
} else {
position = decodeString(data, position, registers);
} }
unsafe.putObject(message, fieldOffset, registers.object1); unsafe.putObject(message, fieldOffset, registers.object1);
currentPresenceField |= presenceMask; currentPresenceField |= presenceMask;
@ -4090,10 +4099,14 @@ final class MessageSchema<T> implements Schema<T> {
position = decodeVarint32(data, position, registers); position = decodeVarint32(data, position, registers);
final int enumValue = registers.int1; final int enumValue = registers.int1;
EnumVerifier enumVerifier = getEnumFieldVerifier(pos); EnumVerifier enumVerifier = getEnumFieldVerifier(pos);
if (enumVerifier == null || enumVerifier.isInRange(enumValue)) { if (!isLegacyEnumIsClosed(typeAndOffset)
|| enumVerifier == null
|| enumVerifier.isInRange(enumValue)) {
// Parse open enums and in-range closed enums into their fields directly.
unsafe.putInt(message, fieldOffset, enumValue); unsafe.putInt(message, fieldOffset, enumValue);
currentPresenceField |= presenceMask; currentPresenceField |= presenceMask;
} else { } else {
// Store out-of-range closed enums in unknown fields.
// UnknownFieldSetLite requires varint to be represented as Long. // UnknownFieldSetLite requires varint to be represented as Long.
getMutableUnknownFields(message).storeField(tag, (long) enumValue); getMutableUnknownFields(message).storeField(tag, (long) enumValue);
} }
@ -4141,7 +4154,7 @@ final class MessageSchema<T> implements Schema<T> {
break; break;
} }
} else if (fieldType == 27) { } else if (fieldType == 27) {
// Handle repeated message fields. // Handle repeated message field.
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) { if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
ProtobufList<?> list = (ProtobufList<?>) unsafe.getObject(message, fieldOffset); ProtobufList<?> list = (ProtobufList<?>) unsafe.getObject(message, fieldOffset);
if (!list.isModifiable()) { if (!list.isModifiable()) {
@ -4205,7 +4218,7 @@ final class MessageSchema<T> implements Schema<T> {
} }
} }
} }
if (tag == endGroup && endGroup != 0) { if (tag == endDelimited && endDelimited != 0) {
break; break;
} }
@ -4237,12 +4250,12 @@ final class MessageSchema<T> implements Schema<T> {
((UnknownFieldSchema<UnknownFieldSetLite, UnknownFieldSetLite>) unknownFieldSchema) ((UnknownFieldSchema<UnknownFieldSetLite, UnknownFieldSetLite>) unknownFieldSchema)
.setBuilderToMessage(message, unknownFields); .setBuilderToMessage(message, unknownFields);
} }
if (endGroup == 0) { if (endDelimited == 0) {
if (position != limit) { if (position != limit) {
throw InvalidProtocolBufferException.parseFailure(); throw InvalidProtocolBufferException.parseFailure();
} }
} else { } else {
if (position > limit || tag != endGroup) { if (position > limit || tag != endDelimited) {
throw InvalidProtocolBufferException.parseFailure(); throw InvalidProtocolBufferException.parseFailure();
} }
} }
@ -4304,266 +4317,10 @@ final class MessageSchema<T> implements Schema<T> {
setOneofPresent(message, fieldNumber, pos); setOneofPresent(message, fieldNumber, pos);
} }
/**
 * Parses a proto3 message from {@code data}, starting at {@code position} and stopping at
 * {@code limit}, and returns the limit if parsing is successful.
 *
 * <p>Each iteration decodes one tag, looks the field number up in the schema, and, when the
 * field is known and the wire type matches, writes the decoded value straight into
 * {@code message}. Every other case (unknown field number, mismatched wire type, or a helper that
 * did not advance the position) is routed to {@code decodeUnknownField}.
 *
 * <p>Presence bits for singular fields are accumulated in a local int and written back to the
 * message only when a field using a different presence word is seen, and once more after the
 * loop.
 *
 * @param message the message to populate; {@code checkMutable} is called on it first
 * @param data the buffer holding the serialized bytes
 * @param position index in {@code data} of the first byte to parse
 * @param limit index in {@code data} at which parsing must end
 * @param registers scratch object through which the decode helpers hand back results ({@code
 *     int1}, {@code long1}, {@code object1})
 * @return the position reached after parsing, which equals {@code limit} on success
 * @throws IOException if decoding fails; in particular a parse failure is thrown when the final
 *     position is not exactly {@code limit}
 */
@CanIgnoreReturnValue
private int parseProto3Message(
T message, byte[] data, int position, int limit, Registers registers) throws IOException {
checkMutable(message);
final sun.misc.Unsafe unsafe = UNSAFE;
// Offset of the presence word currently cached in currentPresenceField, or the sentinel when no
// real presence word is cached.
int currentPresenceFieldOffset = NO_PRESENCE_SENTINEL;
int currentPresenceField = 0;
int tag = 0;
// Previous field number and schema position, used to pick the lookup strategy below.
int oldNumber = -1;
int pos = 0;
while (position < limit) {
// A non-negative first byte is the whole tag; a negative one means the varint continues, so the
// remaining bytes are decoded and the result is read back from registers.int1.
tag = data[position++];
if (tag < 0) {
position = decodeVarint32(tag, data, position, registers);
tag = registers.int1;
}
// The low three bits of the tag are the wire type; the remaining upper bits are the field number.
final int number = tag >>> 3;
final int wireType = tag & 0x7;
// When field numbers arrive in increasing order, the previous position is passed along as a
// hint (presumably to shorten the lookup); otherwise a lookup without a hint is done.
if (number > oldNumber) {
pos = positionForFieldNumber(number, pos / INTS_PER_FIELD);
} else {
pos = positionForFieldNumber(number);
}
oldNumber = number;
if (pos == -1) {
// need to reset
// The field number is not in the schema: clear the hint and let the unknown-field handling at
// the bottom of the loop consume this field.
pos = 0;
} else {
final int typeAndOffset = buffer[pos + 1];
final int fieldType = type(typeAndOffset);
final long fieldOffset = offset(typeAndOffset);
if (fieldType <= 17) {
// Proto3 optional fields have has-bits.
final int presenceMaskAndOffset = buffer[pos + 2];
final int presenceMask = 1 << (presenceMaskAndOffset >>> OFFSET_BITS);
final int presenceFieldOffset = presenceMaskAndOffset & OFFSET_MASK;
// We cache the 32-bit has-bits integer value and only write it back when parsing a field
// using a different has-bits integer.
//
// Note that for fields that do not have hasbits, we unconditionally write and discard
// the data.
if (presenceFieldOffset != currentPresenceFieldOffset) {
if (currentPresenceFieldOffset != NO_PRESENCE_SENTINEL) {
unsafe.putInt(message, (long) currentPresenceFieldOffset, currentPresenceField);
}
if (presenceFieldOffset != NO_PRESENCE_SENTINEL) {
currentPresenceField = unsafe.getInt(message, (long) presenceFieldOffset);
}
currentPresenceFieldOffset = presenceFieldOffset;
}
// In every case below: if the wire type matches, the value is decoded, stored, the presence
// bit is set in the cached word, and the loop continues with the next tag. If the wire type
// does not match, the case breaks out and the field is handled as an unknown field.
switch (fieldType) {
case 0: // DOUBLE:
if (wireType == WireFormat.WIRETYPE_FIXED64) {
UnsafeUtil.putDouble(message, fieldOffset, decodeDouble(data, position));
position += 8;
currentPresenceField |= presenceMask;
continue;
}
break;
case 1: // FLOAT:
if (wireType == WireFormat.WIRETYPE_FIXED32) {
UnsafeUtil.putFloat(message, fieldOffset, decodeFloat(data, position));
position += 4;
currentPresenceField |= presenceMask;
continue;
}
break;
case 2: // INT64:
case 3: // UINT64:
if (wireType == WireFormat.WIRETYPE_VARINT) {
position = decodeVarint64(data, position, registers);
unsafe.putLong(message, fieldOffset, registers.long1);
currentPresenceField |= presenceMask;
continue;
}
break;
case 4: // INT32:
case 11: // UINT32:
if (wireType == WireFormat.WIRETYPE_VARINT) {
position = decodeVarint32(data, position, registers);
unsafe.putInt(message, fieldOffset, registers.int1);
currentPresenceField |= presenceMask;
continue;
}
break;
case 5: // FIXED64:
case 14: // SFIXED64:
if (wireType == WireFormat.WIRETYPE_FIXED64) {
unsafe.putLong(message, fieldOffset, decodeFixed64(data, position));
position += 8;
currentPresenceField |= presenceMask;
continue;
}
break;
case 6: // FIXED32:
case 13: // SFIXED32:
if (wireType == WireFormat.WIRETYPE_FIXED32) {
unsafe.putInt(message, fieldOffset, decodeFixed32(data, position));
position += 4;
currentPresenceField |= presenceMask;
continue;
}
break;
case 7: // BOOL:
if (wireType == WireFormat.WIRETYPE_VARINT) {
// The bool is read as a 64-bit varint and treated as true for any non-zero value.
position = decodeVarint64(data, position, registers);
UnsafeUtil.putBoolean(message, fieldOffset, registers.long1 != 0);
currentPresenceField |= presenceMask;
continue;
}
break;
case 8: // STRING:
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
// The ENFORCE_UTF8 bit of the schema entry selects the UTF-8-requiring decoder.
if ((typeAndOffset & ENFORCE_UTF8_MASK) == 0) {
position = decodeString(data, position, registers);
} else {
position = decodeStringRequireUtf8(data, position, registers);
}
unsafe.putObject(message, fieldOffset, registers.object1);
currentPresenceField |= presenceMask;
continue;
}
break;
case 9: // MESSAGE:
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
// Merge into a mutable instance of the sub-message, then store it back into the field.
final Object current = mutableMessageFieldForMerge(message, pos);
position =
mergeMessageField(
current, getMessageFieldSchema(pos), data, position, limit, registers);
storeMessageField(message, pos, current);
currentPresenceField |= presenceMask;
continue;
}
break;
case 10: // BYTES:
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
position = decodeBytes(data, position, registers);
unsafe.putObject(message, fieldOffset, registers.object1);
currentPresenceField |= presenceMask;
continue;
}
break;
case 12: // ENUM:
if (wireType == WireFormat.WIRETYPE_VARINT) {
// The decoded number is stored as-is; no enum verifier is consulted on this path.
position = decodeVarint32(data, position, registers);
unsafe.putInt(message, fieldOffset, registers.int1);
currentPresenceField |= presenceMask;
continue;
}
break;
case 15: // SINT32:
if (wireType == WireFormat.WIRETYPE_VARINT) {
position = decodeVarint32(data, position, registers);
unsafe.putInt(
message, fieldOffset, CodedInputStream.decodeZigZag32(registers.int1));
currentPresenceField |= presenceMask;
continue;
}
break;
case 16: // SINT64:
if (wireType == WireFormat.WIRETYPE_VARINT) {
position = decodeVarint64(data, position, registers);
unsafe.putLong(
message, fieldOffset, CodedInputStream.decodeZigZag64(registers.long1));
currentPresenceField |= presenceMask;
continue;
}
break;
default:
break;
}
} else if (fieldType == 27) {
// Handle repeated message field.
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
ProtobufList<?> list = (ProtobufList<?>) unsafe.getObject(message, fieldOffset);
// An unmodifiable list is replaced by a mutable copy (default capacity when empty, otherwise
// twice the current size) before elements are appended.
if (!list.isModifiable()) {
final int size = list.size();
list =
list.mutableCopyWithCapacity(
size == 0 ? AbstractProtobufList.DEFAULT_CAPACITY : size * 2);
unsafe.putObject(message, fieldOffset, list);
}
position =
decodeMessageList(
getMessageFieldSchema(pos), tag, data, position, limit, list, registers);
continue;
}
} else if (fieldType <= 49) {
// Handle all other repeated fields.
// If the helper leaves position unchanged, the field is handled as an unknown field below.
final int oldPosition = position;
position =
parseRepeatedField(
message,
data,
position,
limit,
tag,
number,
wireType,
pos,
typeAndOffset,
fieldType,
fieldOffset,
registers);
if (position != oldPosition) {
continue;
}
} else if (fieldType == 50) {
// Field type 50 is dispatched to parseMapField; as above, an unchanged position means the
// field is handled as an unknown field below.
if (wireType == WireFormat.WIRETYPE_LENGTH_DELIMITED) {
final int oldPosition = position;
position = parseMapField(message, data, position, limit, pos, fieldOffset, registers);
if (position != oldPosition) {
continue;
}
}
} else {
// Every remaining field type is dispatched to parseOneofField, with the same
// unchanged-position convention.
final int oldPosition = position;
position =
parseOneofField(
message,
data,
position,
limit,
tag,
number,
wireType,
typeAndOffset,
fieldType,
fieldOffset,
pos,
registers);
if (position != oldPosition) {
continue;
}
}
}
// Reached for unknown field numbers, mismatched wire types, and helpers that did not advance:
// the field is decoded into the message's unknown field set.
position = decodeUnknownField(
tag, data, position, limit, getMutableUnknownFields(message), registers);
}
// Flush the cached presence word, if a real one is cached.
if (currentPresenceFieldOffset != NO_PRESENCE_SENTINEL) {
unsafe.putInt(message, (long) currentPresenceFieldOffset, currentPresenceField);
}
// Parsing must end exactly at the limit; anything else is reported as a parse failure.
if (position != limit) {
throw InvalidProtocolBufferException.parseFailure();
}
return position;
}
@Override @Override
public void mergeFrom(T message, byte[] data, int position, int limit, Registers registers) public void mergeFrom(T message, byte[] data, int position, int limit, Registers registers)
throws IOException { throws IOException {
switch (syntax) { parseMessage(message, data, position, limit, 0, registers);
case PROTO3:
parseProto3Message(message, data, position, limit, registers);
break;
case PROTO2:
parseProto2Message(message, data, position, limit, 0, registers);
break;
}
} }
@Override @Override
@ -4935,6 +4692,10 @@ final class MessageSchema<T> implements Schema<T> {
return (value & ENFORCE_UTF8_MASK) != 0; return (value & ENFORCE_UTF8_MASK) != 0;
} }
/** Reports whether the closed-legacy-enum flag bit is set in the given packed schema entry. */
private static boolean isLegacyEnumIsClosed(int value) {
  final int closedFlag = value & LEGACY_ENUM_IS_CLOSED_MASK;
  return closedFlag != 0;
}
private static long offset(int value) { private static long offset(int value) {
return value & OFFSET_MASK; return value & OFFSET_MASK;
} }

@ -372,6 +372,7 @@ inline bool ExposePublicParser(const FileDescriptor* descriptor) {
// but in the message and can be queried using additional getters that return // but in the message and can be queried using additional getters that return
// ints. // ints.
inline bool SupportUnknownEnumValue(const FieldDescriptor* field) { inline bool SupportUnknownEnumValue(const FieldDescriptor* field) {
// TODO(b/279034699): Check Java legacy_enum_field_treated_as_closed feature.
return !field->legacy_enum_field_treated_as_closed(); return !field->legacy_enum_field_treated_as_closed();
} }

Loading…
Cancel
Save