From 921bdaaa61b8b62437a3cd999c1962c453ff937b Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Thu, 23 Apr 2020 16:13:06 +1200 Subject: [PATCH] Improve repeated fixed parsing performance --- .../ParseMessagesBenchmark.cs | 39 ++++++++ .../Collections/RepeatedFieldTest.cs | 89 +++++++++++++++++++ .../ReadOnlySequenceFactory.cs | 4 +- .../Collections/RepeatedField.cs | 45 +++++++++- 4 files changed, 173 insertions(+), 4 deletions(-) diff --git a/csharp/src/Google.Protobuf.Benchmarks/ParseMessagesBenchmark.cs b/csharp/src/Google.Protobuf.Benchmarks/ParseMessagesBenchmark.cs index 30f3e9ef37..92c3dbe5f0 100644 --- a/csharp/src/Google.Protobuf.Benchmarks/ParseMessagesBenchmark.cs +++ b/csharp/src/Google.Protobuf.Benchmarks/ParseMessagesBenchmark.cs @@ -37,6 +37,7 @@ using System.IO; using System.Linq; using System.Buffers; using Google.Protobuf.WellKnownTypes; +using Benchmarks.Proto3; namespace Google.Protobuf.Benchmarks { @@ -50,6 +51,7 @@ namespace Google.Protobuf.Benchmarks SubTest manyWrapperFieldsTest = new SubTest(CreateManyWrapperFieldsMessage(), ManyWrapperFieldsMessage.Parser, () => new ManyWrapperFieldsMessage(), MaxMessages); SubTest manyPrimitiveFieldsTest = new SubTest(CreateManyPrimitiveFieldsMessage(), ManyPrimitiveFieldsMessage.Parser, () => new ManyPrimitiveFieldsMessage(), MaxMessages); + SubTest repeatedFieldTest = new SubTest(CreateRepeatedFieldMessage(), GoogleMessage1.Parser, () => new GoogleMessage1(), MaxMessages); SubTest emptyMessageTest = new SubTest(new Empty(), Empty.Parser, () => new Empty(), MaxMessages); public IEnumerable MessageCountValues => new[] { 10, 100 }; @@ -83,6 +85,18 @@ namespace Google.Protobuf.Benchmarks return manyPrimitiveFieldsTest.ParseFromReadOnlySequence(); } + [Benchmark] + public IMessage RepeatedFieldMessage_ParseFromByteArray() + { + return repeatedFieldTest.ParseFromByteArray(); + } + + [Benchmark] + public IMessage RepeatedFieldMessage_ParseFromReadOnlySequence() + { + return repeatedFieldTest.ParseFromReadOnlySequence(); + } + [Benchmark] public IMessage EmptyMessage_ParseFromByteArray() { @@ -123,6 +137,20 @@ namespace Google.Protobuf.Benchmarks manyPrimitiveFieldsTest.ParseDelimitedMessagesFromReadOnlySequence(messageCount); } + [Benchmark] + [ArgumentsSource(nameof(MessageCountValues))] + public void RepeatedFieldMessage_ParseDelimitedMessagesFromByteArray(int messageCount) + { + repeatedFieldTest.ParseDelimitedMessagesFromByteArray(messageCount); + } + + [Benchmark] + [ArgumentsSource(nameof(MessageCountValues))] + public void RepeatedFieldMessage_ParseDelimitedMessagesFromReadOnlySequence(int messageCount) + { + repeatedFieldTest.ParseDelimitedMessagesFromReadOnlySequence(messageCount); + } + private static ManyWrapperFieldsMessage CreateManyWrapperFieldsMessage() { // Example data match data of an internal benchmarks @@ -157,6 +185,17 @@ namespace Google.Protobuf.Benchmarks }; } + private static GoogleMessage1 CreateRepeatedFieldMessage() + { + // Message with a repeated fixed length item collection + var message = new GoogleMessage1(); + for (ulong i = 0; i < 1000; i++) + { + message.Field5.Add(i); + } + return message; + } + private class SubTest { private readonly IMessage message; diff --git a/csharp/src/Google.Protobuf.Test/Collections/RepeatedFieldTest.cs b/csharp/src/Google.Protobuf.Test/Collections/RepeatedFieldTest.cs index 527c7d404c..61453e5ab2 100644 --- a/csharp/src/Google.Protobuf.Test/Collections/RepeatedFieldTest.cs +++ b/csharp/src/Google.Protobuf.Test/Collections/RepeatedFieldTest.cs @@ -595,6 +595,95 @@ namespace Google.Protobuf.Collections Assert.AreEqual(((SampleEnum)(-5)), values[5]); } + [Test] + public void TestPackedRepeatedFieldCollectionNonDivisibleLength() + { + uint tag = WireFormat.MakeTag(10, WireFormat.WireType.LengthDelimited); + var codec = FieldCodec.ForFixed32(tag); + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + output.WriteTag(tag); + output.WriteString("A long string"); + output.WriteTag(codec.Tag); + output.WriteRawVarint32((uint)codec.FixedSize - 1); // Length not divisible by FixedSize + output.WriteFixed32(uint.MaxValue); + output.Flush(); + stream.Position = 0; + + var input = new CodedInputStream(stream); + input.ReadTag(); + input.ReadString(); + input.ReadTag(); + var field = new RepeatedField(); + Assert.Throws(() => field.AddEntriesFrom(input, codec)); + + // Collection was not pre-initialized + Assert.AreEqual(0, field.Count); + } + + [Test] + public void TestPackedRepeatedFieldCollectionNotAllocatedWhenLengthExceedsBuffer() + { + uint tag = WireFormat.MakeTag(10, WireFormat.WireType.LengthDelimited); + var codec = FieldCodec.ForFixed32(tag); + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + output.WriteTag(tag); + output.WriteString("A long string"); + output.WriteTag(codec.Tag); + output.WriteRawVarint32((uint)codec.FixedSize); + // Note that there is no content for the packed field. + // The field length exceeds the remaining length of content. + output.Flush(); + stream.Position = 0; + + var input = new CodedInputStream(stream); + input.ReadTag(); + input.ReadString(); + input.ReadTag(); + var field = new RepeatedField(); + Assert.Throws(() => field.AddEntriesFrom(input, codec)); + + // Collection was not pre-initialized + Assert.AreEqual(0, field.Count); + } + + [Test] + public void TestPackedRepeatedFieldCollectionNotAllocatedWhenLengthExceedsRemainingData() + { + uint tag = WireFormat.MakeTag(10, WireFormat.WireType.LengthDelimited); + var codec = FieldCodec.ForFixed32(tag); + var stream = new MemoryStream(); + var output = new CodedOutputStream(stream); + output.WriteTag(tag); + output.WriteString("A long string"); + output.WriteTag(codec.Tag); + output.WriteRawVarint32((uint)codec.FixedSize); + // Note that there is no content for the packed field. + // The field length exceeds the remaining length of the buffer. + output.Flush(); + stream.Position = 0; + + var sequence = ReadOnlySequenceFactory.CreateWithContent(stream.ToArray()); + ParseContext.Initialize(sequence, out ParseContext ctx); + + ctx.ReadTag(); + ctx.ReadString(); + ctx.ReadTag(); + var field = new RepeatedField(); + try + { + field.AddEntriesFrom(ref ctx, codec); + Assert.Fail(); + } + catch (InvalidProtocolBufferException) + { + } + + // Collection was not pre-initialized + Assert.AreEqual(0, field.Count); + } + // Fairly perfunctory tests for the non-generic IList implementation [Test] public void IList_Indexer() diff --git a/csharp/src/Google.Protobuf.Test/ReadOnlySequenceFactory.cs b/csharp/src/Google.Protobuf.Test/ReadOnlySequenceFactory.cs index d2050d37ca..588b559e9f 100644 --- a/csharp/src/Google.Protobuf.Test/ReadOnlySequenceFactory.cs +++ b/csharp/src/Google.Protobuf.Test/ReadOnlySequenceFactory.cs @@ -50,9 +50,9 @@ namespace Google.Protobuf while (currentIndex < data.Length) { var segment = new List(); - for (; currentIndex < Math.Min(currentIndex + segmentSize, data.Length); currentIndex++) + while (segment.Count < segmentSize && currentIndex < data.Length) { - segment.Add(data[currentIndex]); + segment.Add(data[currentIndex++]); } segments.Add(segment.ToArray()); segments.Add(new byte[0]); diff --git a/csharp/src/Google.Protobuf/Collections/RepeatedField.cs b/csharp/src/Google.Protobuf/Collections/RepeatedField.cs index 5c39aabafb..b1bd9b13d1 100644 --- a/csharp/src/Google.Protobuf/Collections/RepeatedField.cs +++ b/csharp/src/Google.Protobuf/Collections/RepeatedField.cs @@ -35,6 +35,7 @@ using System.Collections; using System.Collections.Generic; using System.IO; using System.Security; +using System.Threading; namespace Google.Protobuf.Collections { @@ -126,9 +127,31 @@ namespace Google.Protobuf.Collections if (length > 0) { int oldLimit = SegmentedBufferHelper.PushLimit(ref ctx.state, length); - while (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state)) + + // If the content is fixed size then we can calculate the length + // of the repeated field and pre-initialize the underlying collection. + // + // Check that the supplied length doesn't exceed the underlying buffer. + // That prevents a malicious length from initializing a very large collection. + if (codec.FixedSize > 0 && length % codec.FixedSize == 0 && IsDataAvailable(ref ctx, length)) + { + EnsureSize(count + (length / codec.FixedSize)); + + while (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state)) + { + // Only FieldCodecs with a fixed size can reach here, and they are all known + // types that don't allow the user to specify a custom reader action. + // reader action will never return null. + array[count++] = reader(ref ctx); + } + } + else { - Add(reader(ref ctx)); + // Content is variable size so add until we reach the limit. + while (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state)) + { + Add(reader(ref ctx)); + } } SegmentedBufferHelper.PopLimit(ref ctx.state, oldLimit); } @@ -144,6 +167,24 @@ namespace Google.Protobuf.Collections } } + private bool IsDataAvailable(ref ParseContext ctx, int size) + { + // Data fits in remaining buffer + if (size <= ctx.state.bufferSize - ctx.state.bufferPos) + { + return true; + } + + // Data fits in remaining source data. + // Note that this will never be true when reading from a stream as the total length is unknown. + if (size < ctx.state.segmentedBufferHelper.TotalLength - ctx.state.totalBytesRetired - ctx.state.bufferPos) + { + return true; + } + + return false; + } + /// /// Calculates the size of this collection based on the given codec. ///