From a94f57bd69e9a5999ba67736e4d70a9d7f96aaf5 Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Tue, 9 Apr 2024 22:58:08 -0700 Subject: [PATCH] Check that size is non-negative when reading string or bytes in StreamDecoder. This ensures that StreamDecoder throws a InvalidProtocolBufferException instead of an IllegalStateException on some invalid input. All other implementations of CodedInputStream already do this check. PiperOrigin-RevId: 623383287 --- .../com/google/protobuf/CodedInputStream.java | 13 +++ .../google/protobuf/CodedInputStreamTest.java | 81 +++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index 224ced5292..81da417783 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -2278,6 +2278,9 @@ public abstract class CodedInputStream { if (size == 0) { return ""; } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } if (size <= bufferSize) { refillBuffer(size); String result = new String(buffer, pos, size, UTF_8); @@ -2302,6 +2305,8 @@ public abstract class CodedInputStream { tempPos = oldPos; } else if (size == 0) { return ""; + } else if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); } else if (size <= bufferSize) { refillBuffer(size); bytes = buffer; @@ -2396,6 +2401,9 @@ public abstract class CodedInputStream { if (size == 0) { return ByteString.EMPTY; } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } return readBytesSlowPath(size); } @@ -2408,6 +2416,8 @@ public abstract class CodedInputStream { final byte[] result = Arrays.copyOfRange(buffer, pos, pos + size); pos += size; return result; + } else if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); } else { // Slow path: Build a byte array first then copy it. // TODO: Do we want to protect from malicious input streams here? @@ -2427,6 +2437,9 @@ public abstract class CodedInputStream { if (size == 0) { return Internal.EMPTY_BYTE_BUFFER; } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } // Slow path: Build a byte array first then copy it. // We must copy as the byte array was handed off to the InputStream and a malicious diff --git a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java index 5162888d07..ff700587a1 100644 --- a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java +++ b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java @@ -10,6 +10,7 @@ package com.google.protobuf; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertThrows; import protobuf_unittest.UnittestProto.BoolMessage; import protobuf_unittest.UnittestProto.Int32Message; import protobuf_unittest.UnittestProto.Int64Message; @@ -534,6 +535,86 @@ public class CodedInputStreamTest { } } + @Test + public void testReadStringWithSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readString); + } + } + + @Test + public void testReadStringRequireUtf8WithSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readStringRequireUtf8); + } + } + + @Test + public void testReadBytesWithHugeSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readBytes); + } + } + + @Test + public void testReadByteArrayWithHugeSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readByteArray); + } + } + + @Test + public void testReadByteBufferWithSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readByteBuffer); + } + } + /** * Test we can do messages that are up to CodedInputStream#DEFAULT_SIZE_LIMIT in size (2G or * Integer#MAX_SIZE).