diff --git a/javanano/src/main/java/com/google/protobuf/nano/CodedOutputByteBufferNano.java b/javanano/src/main/java/com/google/protobuf/nano/CodedOutputByteBufferNano.java index 3590718308..d0bc07e430 100644 --- a/javanano/src/main/java/com/google/protobuf/nano/CodedOutputByteBufferNano.java +++ b/javanano/src/main/java/com/google/protobuf/nano/CodedOutputByteBufferNano.java @@ -31,6 +31,9 @@ package com.google.protobuf.nano; import java.io.IOException; +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; +import java.nio.ReadOnlyBufferException; /** * Encodes and writes protocol message fields. @@ -47,15 +50,17 @@ import java.io.IOException; * @author kneton@google.com Kenton Varda */ public final class CodedOutputByteBufferNano { - private final byte[] buffer; - private final int limit; - private int position; + /* max bytes per java UTF-16 char in UTF-8 */ + private static final int MAX_UTF8_EXPANSION = 3; + private final ByteBuffer buffer; private CodedOutputByteBufferNano(final byte[] buffer, final int offset, final int length) { + this(ByteBuffer.wrap(buffer, offset, length)); + } + + private CodedOutputByteBufferNano(final ByteBuffer buffer) { this.buffer = buffer; - position = offset; - limit = offset + length; } /** @@ -287,14 +292,204 @@ public final class CodedOutputByteBufferNano { /** Write a {@code string} field to the stream. */ public void writeStringNoTag(final String value) throws IOException { - // Unfortunately there does not appear to be any way to tell Java to encode - // UTF-8 directly into our buffer, so we have to let it create its own byte - // array and then copy. - final byte[] bytes = value.getBytes(InternalNano.UTF_8); - writeRawVarint32(bytes.length); - writeRawBytes(bytes); + // UTF-8 byte length of the string is at least its UTF-16 code unit length (value.length()), + // and at most 3 times of it. Optimize for the case where we know this length results in a + // constant varint length - saves measuring length of the string. + try { + final int minLengthVarIntSize = computeRawVarint32Size(value.length()); + final int maxLengthVarIntSize = computeRawVarint32Size(value.length() * MAX_UTF8_EXPANSION); + if (minLengthVarIntSize == maxLengthVarIntSize) { + int oldPosition = buffer.position(); + buffer.position(oldPosition + minLengthVarIntSize); + encode(value, buffer); + int newPosition = buffer.position(); + buffer.position(oldPosition); + writeRawVarint32(newPosition - oldPosition - minLengthVarIntSize); + buffer.position(newPosition); + } else { + writeRawVarint32(encodedLength(value)); + encode(value, buffer); + } + } catch (BufferOverflowException e) { + throw new OutOfSpaceException(buffer.position(), buffer.limit()); + } + } + + // These UTF-8 handling methods are copied from Guava's Utf8 class. + /** + * Returns the number of bytes in the UTF-8-encoded form of {@code sequence}. For a string, + * this method is equivalent to {@code string.getBytes(UTF_8).length}, but is more efficient in + * both time and space. + * + * @throws IllegalArgumentException if {@code sequence} contains ill-formed UTF-16 (unpaired + * surrogates) + */ + private static int encodedLength(CharSequence sequence) { + // Warning to maintainers: this implementation is highly optimized. + int utf16Length = sequence.length(); + int utf8Length = utf16Length; + int i = 0; + + // This loop optimizes for pure ASCII. + while (i < utf16Length && sequence.charAt(i) < 0x80) { + i++; + } + + // This loop optimizes for chars less than 0x800. + for (; i < utf16Length; i++) { + char c = sequence.charAt(i); + if (c < 0x800) { + utf8Length += ((0x7f - c) >>> 31); // branch free! + } else { + utf8Length += encodedLengthGeneral(sequence, i); + break; + } + } + + if (utf8Length < utf16Length) { + // Necessary and sufficient condition for overflow because of maximum 3x expansion + throw new IllegalArgumentException("UTF-8 length does not fit in int: " + + (utf8Length + (1L << 32))); + } + return utf8Length; + } + + private static int encodedLengthGeneral(CharSequence sequence, int start) { + int utf16Length = sequence.length(); + int utf8Length = 0; + for (int i = start; i < utf16Length; i++) { + char c = sequence.charAt(i); + if (c < 0x800) { + utf8Length += (0x7f - c) >>> 31; // branch free! + } else { + utf8Length += 2; + // jdk7+: if (Character.isSurrogate(c)) { + if (Character.MIN_SURROGATE <= c && c <= Character.MAX_SURROGATE) { + // Check that we have a well-formed surrogate pair. + int cp = Character.codePointAt(sequence, i); + if (cp < Character.MIN_SUPPLEMENTARY_CODE_POINT) { + throw new IllegalArgumentException("Unpaired surrogate at index " + i); + } + i++; + } + } + } + return utf8Length; } + /** + * Encodes {@code sequence} into UTF-8, in {@code byteBuffer}. For a string, this method is + * equivalent to {@code buffer.put(string.getBytes(UTF_8))}, but is more efficient in both time + * and space. Bytes are written starting at the current position. This method requires paired + * surrogates, and therefore does not support chunking. + * + *

To ensure sufficient space in the output buffer, either call {@link #encodedLength} to + * compute the exact amount needed, or leave room for {@code 3 * sequence.length()}, which is the + * largest possible number of bytes that any input can be encoded to. + * + * @throws IllegalArgumentException if {@code sequence} contains ill-formed UTF-16 (unpaired + * surrogates) + * @throws BufferOverflowException if {@code sequence} encoded in UTF-8 does not fit in + * {@code byteBuffer}'s remaining space. + * @throws ReadOnlyBufferException if {@code byteBuffer} is a read-only buffer. + */ + private static void encode(CharSequence sequence, ByteBuffer byteBuffer) { + if (byteBuffer.isReadOnly()) { + throw new ReadOnlyBufferException(); + } else if (byteBuffer.hasArray()) { + try { + int encoded = encode(sequence, + byteBuffer.array(), + byteBuffer.arrayOffset() + byteBuffer.position(), + byteBuffer.remaining()); + byteBuffer.position(encoded - byteBuffer.arrayOffset()); + } catch (ArrayIndexOutOfBoundsException e) { + BufferOverflowException boe = new BufferOverflowException(); + boe.initCause(e); + throw boe; + } + } else { + encodeDirect(sequence, byteBuffer); + } + } + + private static void encodeDirect(CharSequence sequence, ByteBuffer byteBuffer) { + int utf16Length = sequence.length(); + for (int i = 0; i < utf16Length; i++) { + final char c = sequence.charAt(i); + if (c < 0x80) { // ASCII + byteBuffer.put((byte) c); + } else if (c < 0x800) { // 11 bits, two UTF-8 bytes + byteBuffer.put((byte) ((0xF << 6) | (c >>> 6))); + byteBuffer.put((byte) (0x80 | (0x3F & c))); + } else if (c < Character.MIN_SURROGATE || Character.MAX_SURROGATE < c) { + // Maximium single-char code point is 0xFFFF, 16 bits, three UTF-8 bytes + byteBuffer.put((byte) ((0xF << 5) | (c >>> 12))); + byteBuffer.put((byte) (0x80 | (0x3F & (c >>> 6)))); + byteBuffer.put((byte) (0x80 | (0x3F & c))); + } else { + final char low; + if (i + 1 == sequence.length() + || !Character.isSurrogatePair(c, (low = sequence.charAt(++i)))) { + throw new IllegalArgumentException("Unpaired surrogate at index " + (i - 1)); + } + int codePoint = Character.toCodePoint(c, low); + byteBuffer.put((byte) ((0xF << 4) | (codePoint >>> 18))); + byteBuffer.put((byte) (0x80 | (0x3F & (codePoint >>> 12)))); + byteBuffer.put((byte) (0x80 | (0x3F & (codePoint >>> 6)))); + byteBuffer.put((byte) (0x80 | (0x3F & codePoint))); + } + } + } + + private static int encode(CharSequence sequence, byte[] bytes, int offset, int length) { + int utf16Length = sequence.length(); + int j = offset; + int i = 0; + int limit = offset + length; + // Designed to take advantage of + // https://wikis.oracle.com/display/HotSpotInternals/RangeCheckElimination + for (char c; i < utf16Length && i + j < limit && (c = sequence.charAt(i)) < 0x80; i++) { + bytes[j + i] = (byte) c; + } + if (i == utf16Length) { + return j + utf16Length; + } + j += i; + for (char c; i < utf16Length; i++) { + c = sequence.charAt(i); + if (c < 0x80 && j < limit) { + bytes[j++] = (byte) c; + } else if (c < 0x800 && j <= limit - 2) { // 11 bits, two UTF-8 bytes + bytes[j++] = (byte) ((0xF << 6) | (c >>> 6)); + bytes[j++] = (byte) (0x80 | (0x3F & c)); + } else if ((c < Character.MIN_SURROGATE || Character.MAX_SURROGATE < c) && j <= limit - 3) { + // Maximum single-char code point is 0xFFFF, 16 bits, three UTF-8 bytes + bytes[j++] = (byte) ((0xF << 5) | (c >>> 12)); + bytes[j++] = (byte) (0x80 | (0x3F & (c >>> 6))); + bytes[j++] = (byte) (0x80 | (0x3F & c)); + } else if (j <= limit - 4) { + // Minimum code point represented by a surrogate pair is 0x10000, 17 bits, four UTF-8 bytes + final char low; + if (i + 1 == sequence.length() + || !Character.isSurrogatePair(c, (low = sequence.charAt(++i)))) { + throw new IllegalArgumentException("Unpaired surrogate at index " + (i - 1)); + } + int codePoint = Character.toCodePoint(c, low); + bytes[j++] = (byte) ((0xF << 4) | (codePoint >>> 18)); + bytes[j++] = (byte) (0x80 | (0x3F & (codePoint >>> 12))); + bytes[j++] = (byte) (0x80 | (0x3F & (codePoint >>> 6))); + bytes[j++] = (byte) (0x80 | (0x3F & codePoint)); + } else { + throw new ArrayIndexOutOfBoundsException("Failed writing " + c + " at index " + j); + } + } + return j; + } + + // End guava UTF-8 methods + + /** Write a {@code group} field to the stream. */ public void writeGroupNoTag(final MessageNano value) throws IOException { value.writeTo(this); @@ -602,9 +797,8 @@ public final class CodedOutputByteBufferNano { * {@code string} field. */ public static int computeStringSizeNoTag(final String value) { - final byte[] bytes = value.getBytes(InternalNano.UTF_8); - return computeRawVarint32Size(bytes.length) + - bytes.length; + final int length = encodedLength(value); + return computeRawVarint32Size(length) + length; } /** @@ -687,7 +881,7 @@ public final class CodedOutputByteBufferNano { * Otherwise, throws {@code UnsupportedOperationException}. */ public int spaceLeft() { - return limit - position; + return buffer.remaining(); } /** @@ -720,12 +914,12 @@ public final class CodedOutputByteBufferNano { /** Write a single byte. */ public void writeRawByte(final byte value) throws IOException { - if (position == limit) { + if (!buffer.hasRemaining()) { // We're writing to a single buffer. - throw new OutOfSpaceException(position, limit); + throw new OutOfSpaceException(buffer.position(), buffer.limit()); } - buffer[position++] = value; + buffer.put(value); } /** Write a single byte, represented by an integer value. */ @@ -741,13 +935,11 @@ public final class CodedOutputByteBufferNano { /** Write part of an array of bytes. */ public void writeRawBytes(final byte[] value, int offset, int length) throws IOException { - if (limit - position >= length) { - // We have room in the current buffer. - System.arraycopy(value, offset, buffer, position, length); - position += length; + if (buffer.remaining() >= length) { + buffer.put(value, offset, length); } else { // We're writing to a single buffer. - throw new OutOfSpaceException(position, limit); + throw new OutOfSpaceException(buffer.position(), buffer.limit()); } } diff --git a/javanano/src/test/java/com/google/protobuf/nano/NanoTest.java b/javanano/src/test/java/com/google/protobuf/nano/NanoTest.java index 7de84310c0..d60e94ff31 100644 --- a/javanano/src/test/java/com/google/protobuf/nano/NanoTest.java +++ b/javanano/src/test/java/com/google/protobuf/nano/NanoTest.java @@ -2300,6 +2300,42 @@ public class NanoTest extends TestCase { } } + public void testDifferentStringLengthsNano() throws Exception { + // Test string serialization roundtrip using strings of the following lengths, + // with ASCII and Unicode characters requiring different UTF-8 byte counts per + // char, hence causing the length delimiter varint to sometimes require more + // bytes for the Unicode strings than the ASCII string of the same length. + int[] lengths = new int[] { + 0, + 1, + (1 << 4) - 1, // 1 byte for ASCII and Unicode + (1 << 7) - 1, // 1 byte for ASCII, 2 bytes for Unicode + (1 << 11) - 1, // 2 bytes for ASCII and Unicode + (1 << 14) - 1, // 2 bytes for ASCII, 3 bytes for Unicode + (1 << 17) - 1, // 3 bytes for ASCII and Unicode + }; + for (int i : lengths) { + testEncodingOfString('q', i); // 1 byte per char + testEncodingOfString('\u07FF', i); // 2 bytes per char + testEncodingOfString('\u0981', i); // 3 bytes per char + } + } + + private void testEncodingOfString(char c, int length) throws InvalidProtocolBufferNanoException { + TestAllTypesNano testAllTypesNano = new TestAllTypesNano(); + final String fullString = fullString(c, length); + testAllTypesNano.optionalString = fullString; + final TestAllTypesNano resultNano = new TestAllTypesNano(); + MessageNano.mergeFrom(resultNano, MessageNano.toByteArray(testAllTypesNano)); + assertEquals(fullString, resultNano.optionalString); + } + + private String fullString(char c, int length) { + char[] result = new char[length]; + Arrays.fill(result, c); + return new String(result); + } + public void testNanoWithHasParseFrom() throws Exception { TestAllTypesNanoHas msg = null; // Test false on creation, after clear and upon empty parse.