Add recursion check when parsing unknown fields in Java.

PiperOrigin-RevId: 675657198
pull/18387/head
Protobuf Team Bot 2 months ago committed by Sandy Zhang
parent 850fcce917
commit 4728531c16
  1. 28
      java/core/src/main/java/com/google/protobuf/ArrayDecoders.java
  2. 6
      java/core/src/main/java/com/google/protobuf/CodedInputStream.java
  3. 12
      java/core/src/main/java/com/google/protobuf/MessageSchema.java
  4. 3
      java/core/src/main/java/com/google/protobuf/MessageSetSchema.java
  5. 29
      java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java
  6. 158
      java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java
  7. 232
      java/lite/src/test/java/com/google/protobuf/LiteTest.java

@ -23,6 +23,10 @@ import java.io.IOException;
*/
@CheckReturnValue
final class ArrayDecoders {
static final int DEFAULT_RECURSION_LIMIT = 100;
@SuppressWarnings("NonFinalStaticField")
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
private ArrayDecoders() {}
@ -37,6 +41,7 @@ final class ArrayDecoders {
public long long1;
public Object object1;
public final ExtensionRegistryLite extensionRegistry;
public int recursionDepth;
Registers() {
this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry();
@ -244,7 +249,10 @@ final class ArrayDecoders {
if (length < 0 || length > limit - position) {
throw InvalidProtocolBufferException.truncatedMessage();
}
registers.recursionDepth++;
checkRecursionLimit(registers.recursionDepth);
schema.mergeFrom(msg, data, position, position + length, registers);
registers.recursionDepth--;
registers.object1 = msg;
return position + length;
}
@ -262,8 +270,11 @@ final class ArrayDecoders {
// A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
// and it can't be used in group fields).
final MessageSchema messageSchema = (MessageSchema) schema;
registers.recursionDepth++;
checkRecursionLimit(registers.recursionDepth);
final int endPosition =
messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
registers.recursionDepth--;
registers.object1 = msg;
return endPosition;
}
@ -1024,6 +1035,8 @@ final class ArrayDecoders {
final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance();
final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP;
int lastTag = 0;
registers.recursionDepth++;
checkRecursionLimit(registers.recursionDepth);
while (position < limit) {
position = decodeVarint32(data, position, registers);
lastTag = registers.int1;
@ -1032,6 +1045,7 @@ final class ArrayDecoders {
}
position = decodeUnknownField(lastTag, data, position, limit, child, registers);
}
registers.recursionDepth--;
if (position > limit || lastTag != endGroup) {
throw InvalidProtocolBufferException.parseFailure();
}
@ -1078,4 +1092,18 @@ final class ArrayDecoders {
throw InvalidProtocolBufferException.invalidTag();
}
}
/**
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
* the depth of the message exceeds this limit.
*/
public static void setRecursionLimit(int limit) {
recursionLimit = limit;
}
private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException {
if (depth >= recursionLimit) {
throw InvalidProtocolBufferException.recursionLimitExceeded();
}
}
}

@ -229,7 +229,10 @@ public abstract class CodedInputStream {
if (tag == 0) {
return;
}
checkRecursionLimit();
++recursionDepth;
boolean fieldSkipped = skipField(tag);
--recursionDepth;
if (!fieldSkipped) {
return;
}
@ -246,7 +249,10 @@ public abstract class CodedInputStream {
if (tag == 0) {
return;
}
checkRecursionLimit();
++recursionDepth;
boolean fieldSkipped = skipField(tag, output);
--recursionDepth;
if (!fieldSkipped) {
return;
}

@ -3006,8 +3006,8 @@ final class MessageSchema<T> implements Schema<T> {
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
}
// Unknown field.
if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
if (unknownFieldSchema.mergeOneFieldFrom(
unknownFields, reader, /* currentDepth= */ 0)) {
continue;
}
}
@ -3382,8 +3382,8 @@ final class MessageSchema<T> implements Schema<T> {
if (unknownFields == null) {
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
}
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
if (!unknownFieldSchema.mergeOneFieldFrom(
unknownFields, reader, /* currentDepth= */ 0)) {
return;
}
break;
@ -3399,8 +3399,8 @@ final class MessageSchema<T> implements Schema<T> {
if (unknownFields == null) {
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
}
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
if (!unknownFieldSchema.mergeOneFieldFrom(
unknownFields, reader, /* currentDepth= */ 0)) {
return;
}
}

@ -278,8 +278,7 @@ final class MessageSetSchema<T> implements Schema<T> {
reader, extension, extensionRegistry, extensions);
return true;
} else {
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0);
}
} else {
return reader.skipField();

@ -13,6 +13,11 @@ import java.io.IOException;
@CheckReturnValue
abstract class UnknownFieldSchema<T, B> {
static final int DEFAULT_RECURSION_LIMIT = 100;
@SuppressWarnings("NonFinalStaticField")
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
/** Whether unknown fields should be dropped. */
abstract boolean shouldDiscardUnknownFields(Reader reader);
@ -55,7 +60,9 @@ abstract class UnknownFieldSchema<T, B> {
/** Marks unknown fields as immutable. */
abstract void makeImmutable(Object message);
final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
/** Merges one field into the unknown fields. */
final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth)
throws IOException {
int tag = reader.getTag();
int fieldNumber = WireFormat.getTagFieldNumber(tag);
switch (WireFormat.getTagWireType(tag)) {
@ -74,7 +81,12 @@ abstract class UnknownFieldSchema<T, B> {
case WireFormat.WIRETYPE_START_GROUP:
final B subFields = newBuilder();
int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP);
mergeFrom(subFields, reader);
currentDepth++;
if (currentDepth >= recursionLimit) {
throw InvalidProtocolBufferException.recursionLimitExceeded();
}
mergeFrom(subFields, reader, currentDepth);
currentDepth--;
if (endGroupTag != reader.getTag()) {
throw InvalidProtocolBufferException.invalidEndTag();
}
@ -87,10 +99,11 @@ abstract class UnknownFieldSchema<T, B> {
}
}
private final void mergeFrom(B unknownFields, Reader reader) throws IOException {
private final void mergeFrom(B unknownFields, Reader reader, int currentDepth)
throws IOException {
while (true) {
if (reader.getFieldNumber() == Reader.READ_DONE
|| !mergeOneFieldFrom(unknownFields, reader)) {
|| !mergeOneFieldFrom(unknownFields, reader, currentDepth)) {
break;
}
}
@ -107,4 +120,12 @@ abstract class UnknownFieldSchema<T, B> {
abstract int getSerializedSizeAsMessageSet(T message);
abstract int getSerializedSize(T unknowns);
/**
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
* the depth of the message exceeds this limit.
*/
public void setRecursionLimit(int limit) {
recursionLimit = limit;
}
}

@ -11,6 +11,9 @@ import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertThrows;
import com.google.common.primitives.Bytes;
import map_test.MapTestProto.MapContainer;
import protobuf_unittest.UnittestProto.BoolMessage;
import protobuf_unittest.UnittestProto.Int32Message;
import protobuf_unittest.UnittestProto.Int64Message;
@ -35,6 +38,13 @@ public class CodedInputStreamTest {
private static final int DEFAULT_BLOCK_SIZE = 4096;
private static final int GROUP_TAP = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP);
private static final byte[] NESTING_SGROUP = generateSGroupTags();
private static final byte[] NESTING_SGROUP_WITH_INITIAL_BYTES = generateSGroupTagsForMapField();
private enum InputType {
ARRAY {
@Override
@ -117,6 +127,17 @@ public class CodedInputStreamTest {
return bytes;
}
private static byte[] generateSGroupTags() {
byte[] bytes = new byte[100000];
Arrays.fill(bytes, (byte) GROUP_TAP);
return bytes;
}
private static byte[] generateSGroupTagsForMapField() {
byte[] initialBytes = {18, 1, 75, 26, (byte) 198, (byte) 154, 12};
return Bytes.concat(initialBytes, NESTING_SGROUP);
}
/**
* An InputStream which limits the number of bytes it reads at a time. We use this to make sure
* that CodedInputStream doesn't screw up when reading in small blocks.
@ -740,6 +761,143 @@ public class CodedInputStreamTest {
}
}
@Test
public void testMaliciousRecursion_unknownFields() throws Exception {
Throwable thrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> TestRecursiveMessage.parseFrom(NESTING_SGROUP));
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
}
@Test
public void testMaliciousRecursion_skippingUnknownField() throws Exception {
Throwable thrown =
assertThrows(
InvalidProtocolBufferException.class,
() ->
DiscardUnknownFieldsParser.wrap(TestRecursiveMessage.parser())
.parseFrom(NESTING_SGROUP));
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
}
@Test
public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception {
Throwable parseFromThrown =
assertThrows(
InvalidProtocolBufferException.class,
() ->
MapContainer.parseFrom(
new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES)));
Throwable mergeFromThrown =
assertThrows(
InvalidProtocolBufferException.class,
() ->
MapContainer.newBuilder()
.mergeFrom(new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES)));
assertThat(parseFromThrown)
.hasMessageThat()
.contains("Protocol message had too many levels of nesting");
assertThat(mergeFromThrown)
.hasMessageThat()
.contains("Protocol message had too many levels of nesting");
}
@Test
public void testMaliciousSGroupTags_inputStream_skipMessage() throws Exception {
ByteArrayInputStream inputSteam = new ByteArrayInputStream(NESTING_SGROUP);
CodedInputStream input = CodedInputStream.newInstance(inputSteam);
CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
Throwable thrown2 =
assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
assertThat(thrown2)
.hasMessageThat()
.contains("Protocol message had too many levels of nesting");
}
@Test
public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception {
Throwable parseFromThrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> MapContainer.parseFrom(NESTING_SGROUP_WITH_INITIAL_BYTES));
Throwable mergeFromThrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> MapContainer.newBuilder().mergeFrom(NESTING_SGROUP_WITH_INITIAL_BYTES));
assertThat(parseFromThrown)
.hasMessageThat()
.contains("the input ended unexpectedly in the middle of a field");
assertThat(mergeFromThrown)
.hasMessageThat()
.contains("the input ended unexpectedly in the middle of a field");
}
@Test
public void testMaliciousSGroupTags_arrayDecoder_skipMessage() throws Exception {
CodedInputStream input = CodedInputStream.newInstance(NESTING_SGROUP);
CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
Throwable thrown2 =
assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
assertThat(thrown2)
.hasMessageThat()
.contains("Protocol message had too many levels of nesting");
}
@Test
public void testMaliciousSGroupTagsWithMapField_fromByteBuffer() throws Exception {
Throwable thrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> MapContainer.parseFrom(ByteBuffer.wrap(NESTING_SGROUP_WITH_INITIAL_BYTES)));
assertThat(thrown)
.hasMessageThat()
.contains("the input ended unexpectedly in the middle of a field");
}
@Test
public void testMaliciousSGroupTags_byteBuffer_skipMessage() throws Exception {
CodedInputStream input = InputType.NIO_DIRECT.newDecoder(NESTING_SGROUP);
CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
Throwable thrown2 =
assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
assertThat(thrown2)
.hasMessageThat()
.contains("Protocol message had too many levels of nesting");
}
@Test
public void testMaliciousSGroupTags_iterableByteBuffer() throws Exception {
CodedInputStream input = InputType.ITER_DIRECT.newDecoder(NESTING_SGROUP);
CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
Throwable thrown2 =
assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
assertThat(thrown2)
.hasMessageThat()
.contains("Protocol message had too many levels of nesting");
}
private void checkSizeLimitExceeded(InvalidProtocolBufferException e) {
assertThat(e)
.hasMessageThat()

@ -2463,6 +2463,211 @@ public class LiteTest {
}
}
@Test
public void testParseFromInputStream_concurrent_nestingUnknownGroups() throws Exception {
int numThreads = 200;
ArrayList<Thread> threads = new ArrayList<>();
ByteString byteString = generateNestingGroups(99);
AtomicBoolean thrown = new AtomicBoolean(false);
for (int i = 0; i < numThreads; i++) {
Thread thread =
new Thread(
() -> {
try {
TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString);
} catch (IOException e) {
if (e.getMessage().contains("Protocol message had too many levels of nesting")) {
thrown.set(true);
}
}
});
thread.start();
threads.add(thread);
}
for (Thread thread : threads) {
thread.join();
}
assertThat(thrown.get()).isFalse();
}
@Test
public void testParseFromInputStream_nestingUnknownGroups() throws IOException {
ByteString byteString = generateNestingGroups(99);
Throwable thrown =
assertThrows(
InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
assertThat(thrown)
.hasMessageThat()
.doesNotContain("Protocol message had too many levels of nesting");
}
@Test
public void testParseFromInputStream_nestingUnknownGroups_exception() throws IOException {
ByteString byteString = generateNestingGroups(100);
Throwable thrown =
assertThrows(
InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
}
@Test
public void testParseFromInputStream_setRecursionLimit_exception() throws IOException {
ByteString byteString = generateNestingGroups(199);
UnknownFieldSchema<?, ?> schema = SchemaUtil.unknownFieldSetLiteSchema();
schema.setRecursionLimit(200);
Throwable thrown =
assertThrows(
InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
assertThat(thrown)
.hasMessageThat()
.doesNotContain("Protocol message had too many levels of nesting");
schema.setRecursionLimit(UnknownFieldSchema.DEFAULT_RECURSION_LIMIT);
}
@Test
public void testParseFromBytes_concurrent_nestingUnknownGroups() throws Exception {
int numThreads = 200;
ArrayList<Thread> threads = new ArrayList<>();
ByteString byteString = generateNestingGroups(99);
AtomicBoolean thrown = new AtomicBoolean(false);
for (int i = 0; i < numThreads; i++) {
Thread thread =
new Thread(
() -> {
try {
// Should pass in byte[] instead of ByteString to go into ArrayDecoders.
TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString.toByteArray());
} catch (InvalidProtocolBufferException e) {
if (e.getMessage().contains("Protocol message had too many levels of nesting")) {
thrown.set(true);
}
}
});
thread.start();
threads.add(thread);
}
for (Thread thread : threads) {
thread.join();
}
assertThat(thrown.get()).isFalse();
}
@Test
public void testParseFromBytes_nestingUnknownGroups() throws IOException {
ByteString byteString = generateNestingGroups(99);
Throwable thrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
assertThat(thrown)
.hasMessageThat()
.doesNotContain("Protocol message had too many levels of nesting");
}
@Test
public void testParseFromBytes_nestingUnknownGroups_exception() throws IOException {
ByteString byteString = generateNestingGroups(100);
Throwable thrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
}
@Test
public void testParseFromBytes_setRecursionLimit_exception() throws IOException {
ByteString byteString = generateNestingGroups(199);
ArrayDecoders.setRecursionLimit(200);
Throwable thrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
assertThat(thrown)
.hasMessageThat()
.doesNotContain("Protocol message had too many levels of nesting");
ArrayDecoders.setRecursionLimit(ArrayDecoders.DEFAULT_RECURSION_LIMIT);
}
@Test
public void testParseFromBytes_recursiveMessages() throws Exception {
byte[] data99 = makeRecursiveMessage(99).toByteArray();
byte[] data100 = makeRecursiveMessage(100).toByteArray();
RecursiveMessage unused = RecursiveMessage.parseFrom(data99);
Throwable thrown =
assertThrows(
InvalidProtocolBufferException.class, () -> RecursiveMessage.parseFrom(data100));
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
}
@Test
public void testParseFromBytes_recursiveKnownGroups() throws Exception {
byte[] data99 = makeRecursiveGroup(99).toByteArray();
byte[] data100 = makeRecursiveGroup(100).toByteArray();
RecursiveGroup unused = RecursiveGroup.parseFrom(data99);
Throwable thrown =
assertThrows(InvalidProtocolBufferException.class, () -> RecursiveGroup.parseFrom(data100));
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
}
@Test
@SuppressWarnings("ProtoParseFromByteString")
public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception {
ByteString byteString = generateNestingGroups(102);
Throwable parseFromThrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> MapContainer.parseFrom(byteString.toByteArray()));
Throwable mergeFromThrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> MapContainer.newBuilder().mergeFrom(byteString.toByteArray()));
assertThat(parseFromThrown)
.hasMessageThat()
.contains("Protocol message had too many levels of nesting");
assertThat(mergeFromThrown)
.hasMessageThat()
.contains("Protocol message had too many levels of nesting");
}
@Test
public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception {
byte[] bytes = generateNestingGroups(101).toByteArray();
Throwable parseFromThrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> MapContainer.parseFrom(new ByteArrayInputStream(bytes)));
Throwable mergeFromThrown =
assertThrows(
InvalidProtocolBufferException.class,
() -> MapContainer.newBuilder().mergeFrom(new ByteArrayInputStream(bytes)));
assertThat(parseFromThrown)
.hasMessageThat()
.contains("Protocol message had too many levels of nesting");
assertThat(mergeFromThrown)
.hasMessageThat()
.contains("Protocol message had too many levels of nesting");
}
@Test
public void testParseFromByteBuffer_extensions() throws Exception {
TestAllExtensionsLite message =
@ -2819,4 +3024,31 @@ public class LiteTest {
}
return false;
}
private static ByteString generateNestingGroups(int num) throws IOException {
int groupTap = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP);
ByteString.Output byteStringOutput = ByteString.newOutput();
CodedOutputStream codedOutput = CodedOutputStream.newInstance(byteStringOutput);
for (int i = 0; i < num; i++) {
codedOutput.writeInt32NoTag(groupTap);
}
codedOutput.flush();
return byteStringOutput.toByteString();
}
private static RecursiveMessage makeRecursiveMessage(int num) {
if (num == 0) {
return RecursiveMessage.getDefaultInstance();
} else {
return RecursiveMessage.newBuilder().setRecurse(makeRecursiveMessage(num - 1)).build();
}
}
private static RecursiveGroup makeRecursiveGroup(int num) {
if (num == 0) {
return RecursiveGroup.getDefaultInstance();
} else {
return RecursiveGroup.newBuilder().setRecurse(makeRecursiveGroup(num - 1)).build();
}
}
}

Loading…
Cancel
Save