diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index 8f1ac736d6..80e0e2c2c2 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -2276,6 +2276,9 @@ public abstract class CodedInputStream { if (size == 0) { return ""; } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } if (size <= bufferSize) { refillBuffer(size); String result = new String(buffer, pos, size, UTF_8); @@ -2300,6 +2303,8 @@ public abstract class CodedInputStream { tempPos = oldPos; } else if (size == 0) { return ""; + } else if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); } else if (size <= bufferSize) { refillBuffer(size); bytes = buffer; @@ -2394,6 +2399,9 @@ public abstract class CodedInputStream { if (size == 0) { return ByteString.EMPTY; } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } return readBytesSlowPath(size); } @@ -2406,6 +2414,8 @@ public abstract class CodedInputStream { final byte[] result = Arrays.copyOfRange(buffer, pos, pos + size); pos += size; return result; + } else if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); } else { // Slow path: Build a byte array first then copy it. // TODO: Do we want to protect from malicious input streams here? @@ -2425,6 +2435,9 @@ public abstract class CodedInputStream { if (size == 0) { return Internal.EMPTY_BYTE_BUFFER; } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } // Slow path: Build a byte array first then copy it. // We must copy as the byte array was handed off to the InputStream and a malicious diff --git a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java index 2de3273e34..cc85f96985 100644 --- a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java +++ b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java @@ -10,6 +10,7 @@ package com.google.protobuf; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertThrows; import protobuf_unittest.UnittestProto.BoolMessage; import protobuf_unittest.UnittestProto.Int32Message; import protobuf_unittest.UnittestProto.Int64Message; @@ -534,6 +535,86 @@ public class CodedInputStreamTest { } } + @Test + public void testReadStringWithSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readString); + } + } + + @Test + public void testReadStringRequireUtf8WithSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readStringRequireUtf8); + } + } + + @Test + public void testReadBytesWithHugeSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readBytes); + } + } + + @Test + public void testReadByteArrayWithHugeSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readByteArray); + } + } + + @Test + public void testReadByteBufferWithSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readByteBuffer); + } + } + /** * Test we can do messages that are up to CodedInputStream#DEFAULT_SIZE_LIMIT in size (2G or * Integer#MAX_SIZE).