CodedOutputStream: Avoid updating position to go beyond end of array.

This has twofold goals:
1. Correctness: if position overruns the array, checking space left may return a negative number. I'm not sure how bad that is, but let's avoid it.
2. Performance. This generates more optimal assembly code which can combine bounds checks, particularly on Android (I haven't looked at the generated assembly on the server JVM; it's possible the server JVM can already performance this hoist).

The `position` field is stored on the object, so Android ART generates assembly codes for `this.position++` like "load, add, store":

```
       ldr w3, [x1, #12]
       add w4, w3, #0x1 (1)
       str w4, [x1, #12]
```

There can be a lot of these loads/stores executed each step of a loop (e.g. writeFixed64NoTag updates position 8 times, and varint encoding could do it even more). It's faster if we can hoist these so we load once at the start of the function, and store once at the end of the function. This also has the nice benefit that it won't store if we've thrown an exception.

See before/after in Compiler Explorer: https://godbolt.org/z/bWWYqsxK4. I'm not an assembly expert, but it seems clear that the increment instructions like `add w4, w0, #0x1 (1)` are no longer always surrounded by loads and stores in the new version.

And in Compiler Explorer, you also see `bufferFixed64NoTag` has reduced from 98 lines of assembly to 57 lines of assembly in the hoisted version. This is because we don't need to re-check the array bounds each time we reload `position`. I imagine this also makes any other method with a fixed number of increments like `writeFixed32NoTag` faster too.

PiperOrigin-RevId: 673588324
pull/18226/head
Mark Hansen 6 months ago committed by Copybara-Service
parent f861f12841
commit 3f1de2c6e8
  1. 56
      java/core/src/main/java/com/google/protobuf/CodedOutputStream.java

@ -1306,12 +1306,14 @@ public abstract class CodedOutputStream extends ByteOutput {
@Override
public final void write(byte value) throws IOException {
int position = this.position;
try {
buffer[position++] = value;
} catch (IndexOutOfBoundsException e) {
throw new OutOfSpaceException(
String.format("Pos: %d, limit: %d, len: %d", position, limit, 1), e);
}
this.position = position; // Only update position if we stayed within the array bounds.
}
@Override
@ -1326,11 +1328,12 @@ public abstract class CodedOutputStream extends ByteOutput {
@Override
public final void writeUInt32NoTag(int value) throws IOException {
int position = this.position; // Perf: hoist field to register to avoid load/stores.
try {
while (true) {
if ((value & ~0x7F) == 0) {
buffer[position++] = (byte) value;
return;
break;
} else {
buffer[position++] = (byte) ((value | 0x80) & 0xFF);
value >>>= 7;
@ -1340,10 +1343,12 @@ public abstract class CodedOutputStream extends ByteOutput {
throw new OutOfSpaceException(
String.format("Pos: %d, limit: %d, len: %d", position, limit, 1), e);
}
this.position = position; // Only update position if we stayed within the array bounds.
}
@Override
public final void writeFixed32NoTag(int value) throws IOException {
int position = this.position; // Perf: hoist field to register to avoid load/stores.
try {
buffer[position++] = (byte) (value & 0xFF);
buffer[position++] = (byte) ((value >> 8) & 0xFF);
@ -1353,15 +1358,17 @@ public abstract class CodedOutputStream extends ByteOutput {
throw new OutOfSpaceException(
String.format("Pos: %d, limit: %d, len: %d", position, limit, 1), e);
}
this.position = position; // Only update position if we stayed within the array bounds.
}
@Override
public final void writeUInt64NoTag(long value) throws IOException {
int position = this.position; // Perf: hoist field to register to avoid load/stores.
if (HAS_UNSAFE_ARRAY_OPERATIONS && spaceLeft() >= MAX_VARINT_SIZE) {
while (true) {
if ((value & ~0x7FL) == 0) {
UnsafeUtil.putByte(buffer, position++, (byte) value);
return;
break;
} else {
UnsafeUtil.putByte(buffer, position++, (byte) (((int) value | 0x80) & 0xFF));
value >>>= 7;
@ -1372,7 +1379,7 @@ public abstract class CodedOutputStream extends ByteOutput {
while (true) {
if ((value & ~0x7FL) == 0) {
buffer[position++] = (byte) value;
return;
break;
} else {
buffer[position++] = (byte) (((int) value | 0x80) & 0xFF);
value >>>= 7;
@ -1383,10 +1390,12 @@ public abstract class CodedOutputStream extends ByteOutput {
String.format("Pos: %d, limit: %d, len: %d", position, limit, 1), e);
}
}
this.position = position; // Only update position if we stayed within the array bounds.
}
@Override
public final void writeFixed64NoTag(long value) throws IOException {
int position = this.position; // Perf: hoist field to register to avoid load/stores.
try {
buffer[position++] = (byte) ((int) (value) & 0xFF);
buffer[position++] = (byte) ((int) (value >> 8) & 0xFF);
@ -1400,17 +1409,18 @@ public abstract class CodedOutputStream extends ByteOutput {
throw new OutOfSpaceException(
String.format("Pos: %d, limit: %d, len: %d", position, limit, 1), e);
}
this.position = position; // Only update position if we stayed within the array bounds.
}
@Override
public final void write(byte[] value, int offset, int length) throws IOException {
try {
System.arraycopy(value, offset, buffer, position, length);
position += length;
} catch (IndexOutOfBoundsException e) {
throw new OutOfSpaceException(
String.format("Pos: %d, limit: %d, len: %d", position, limit, length), e);
}
position += length;
}
@Override
@ -1688,7 +1698,7 @@ public abstract class CodedOutputStream extends ByteOutput {
while (true) {
if ((value & ~0x7F) == 0) {
buffer.put((byte) value);
return;
break;
} else {
buffer.put((byte) ((value | 0x80) & 0xFF));
value >>>= 7;
@ -1714,7 +1724,7 @@ public abstract class CodedOutputStream extends ByteOutput {
while (true) {
if ((value & ~0x7FL) == 0) {
buffer.put((byte) value);
return;
break;
} else {
buffer.put((byte) (((int) value | 0x80) & 0xFF));
value >>>= 7;
@ -2014,30 +2024,34 @@ public abstract class CodedOutputStream extends ByteOutput {
@Override
public void writeUInt32NoTag(int value) throws IOException {
long position = this.position; // Perf: hoist field to register to avoid load/stores.
if (position <= oneVarintLimit) {
// Optimization to avoid bounds checks on each iteration.
while (true) {
if ((value & ~0x7F) == 0) {
UnsafeUtil.putByte(position++, (byte) value);
return;
break;
} else {
UnsafeUtil.putByte(position++, (byte) ((value | 0x80) & 0xFF));
value >>>= 7;
}
}
} else {
while (position < limit) {
while (true) {
if (position >= limit) {
throw new OutOfSpaceException(
String.format("Pos: %d, limit: %d, len: %d", position, limit, 1));
}
if ((value & ~0x7F) == 0) {
UnsafeUtil.putByte(position++, (byte) value);
return;
break;
} else {
UnsafeUtil.putByte(position++, (byte) ((value | 0x80) & 0xFF));
value >>>= 7;
}
}
throw new OutOfSpaceException(
String.format("Pos: %d, limit: %d, len: %d", position, limit, 1));
}
this.position = position; // Only update position if we stayed within the array bounds.
}
@Override
@ -2048,12 +2062,13 @@ public abstract class CodedOutputStream extends ByteOutput {
@Override
public void writeUInt64NoTag(long value) throws IOException {
long position = this.position; // Perf: hoist field to register to avoid load/stores.
if (position <= oneVarintLimit) {
// Optimization to avoid bounds checks on each iteration.
while (true) {
if ((value & ~0x7FL) == 0) {
UnsafeUtil.putByte(position++, (byte) value);
return;
break;
} else {
UnsafeUtil.putByte(position++, (byte) (((int) value | 0x80) & 0xFF));
value >>>= 7;
@ -2063,7 +2078,7 @@ public abstract class CodedOutputStream extends ByteOutput {
while (position < limit) {
if ((value & ~0x7FL) == 0) {
UnsafeUtil.putByte(position++, (byte) value);
return;
break;
} else {
UnsafeUtil.putByte(position++, (byte) (((int) value | 0x80) & 0xFF));
value >>>= 7;
@ -2072,6 +2087,7 @@ public abstract class CodedOutputStream extends ByteOutput {
throw new OutOfSpaceException(
String.format("Pos: %d, limit: %d, len: %d", position, limit, 1));
}
this.position = position; // Only update position if we stayed within the array bounds.
}
@Override
@ -2228,7 +2244,9 @@ public abstract class CodedOutputStream extends ByteOutput {
* responsibility of the caller.
*/
final void buffer(byte value) {
int position = this.position;
buffer[position++] = value;
this.position = position; // Only update position if we stayed within the array bounds.
totalBytesWritten++;
}
@ -2258,6 +2276,7 @@ public abstract class CodedOutputStream extends ByteOutput {
* responsibility of the caller.
*/
final void bufferUInt32NoTag(int value) {
int position = this.position; // Perf: hoist field to register to avoid load/stores.
if (HAS_UNSAFE_ARRAY_OPERATIONS) {
final long originalPos = position;
while (true) {
@ -2276,7 +2295,7 @@ public abstract class CodedOutputStream extends ByteOutput {
if ((value & ~0x7F) == 0) {
buffer[position++] = (byte) value;
totalBytesWritten++;
return;
break;
} else {
buffer[position++] = (byte) ((value | 0x80) & 0xFF);
totalBytesWritten++;
@ -2284,6 +2303,7 @@ public abstract class CodedOutputStream extends ByteOutput {
}
}
}
this.position = position; // Only update position if we stayed within the array bounds.
}
/**
@ -2291,6 +2311,7 @@ public abstract class CodedOutputStream extends ByteOutput {
* responsibility of the caller.
*/
final void bufferUInt64NoTag(long value) {
int position = this.position; // Perf: hoist field to register to avoid load/stores.
if (HAS_UNSAFE_ARRAY_OPERATIONS) {
final long originalPos = position;
while (true) {
@ -2309,7 +2330,7 @@ public abstract class CodedOutputStream extends ByteOutput {
if ((value & ~0x7FL) == 0) {
buffer[position++] = (byte) value;
totalBytesWritten++;
return;
break;
} else {
buffer[position++] = (byte) (((int) value | 0x80) & 0xFF);
totalBytesWritten++;
@ -2317,6 +2338,7 @@ public abstract class CodedOutputStream extends ByteOutput {
}
}
}
this.position = position; // Only update position if we stayed within the array bounds.
}
/**
@ -2324,10 +2346,12 @@ public abstract class CodedOutputStream extends ByteOutput {
* responsibility of the caller.
*/
final void bufferFixed32NoTag(int value) {
int position = this.position; // Perf: hoist field to register to avoid load/stores.
buffer[position++] = (byte) (value & 0xFF);
buffer[position++] = (byte) ((value >> 8) & 0xFF);
buffer[position++] = (byte) ((value >> 16) & 0xFF);
buffer[position++] = (byte) ((value >> 24) & 0xFF);
this.position = position; // Only update position if we stayed within the array bounds.
totalBytesWritten += FIXED32_SIZE;
}
@ -2336,6 +2360,7 @@ public abstract class CodedOutputStream extends ByteOutput {
* responsibility of the caller.
*/
final void bufferFixed64NoTag(long value) {
int position = this.position; // Perf: hoist field to register to avoid load/stores.
buffer[position++] = (byte) (value & 0xFF);
buffer[position++] = (byte) ((value >> 8) & 0xFF);
buffer[position++] = (byte) ((value >> 16) & 0xFF);
@ -2344,6 +2369,7 @@ public abstract class CodedOutputStream extends ByteOutput {
buffer[position++] = (byte) ((int) (value >> 40) & 0xFF);
buffer[position++] = (byte) ((int) (value >> 48) & 0xFF);
buffer[position++] = (byte) ((int) (value >> 56) & 0xFF);
this.position = position; // Only update position if we stayed within the array bounds.
totalBytesWritten += FIXED64_SIZE;
}
}

Loading…
Cancel
Save