From 921bdaaa61b8b62437a3cd999c1962c453ff937b Mon Sep 17 00:00:00 2001
From: James Newton-King <james@newtonking.com>
Date: Thu, 23 Apr 2020 16:13:06 +1200
Subject: [PATCH] Improve repeated fixed parsing performance

---
 .../ParseMessagesBenchmark.cs                 | 39 ++++++++
 .../Collections/RepeatedFieldTest.cs          | 89 +++++++++++++++++++
 .../ReadOnlySequenceFactory.cs                |  4 +-
 .../Collections/RepeatedField.cs              | 45 +++++++++-
 4 files changed, 173 insertions(+), 4 deletions(-)

diff --git a/csharp/src/Google.Protobuf.Benchmarks/ParseMessagesBenchmark.cs b/csharp/src/Google.Protobuf.Benchmarks/ParseMessagesBenchmark.cs
index 30f3e9ef37..92c3dbe5f0 100644
--- a/csharp/src/Google.Protobuf.Benchmarks/ParseMessagesBenchmark.cs
+++ b/csharp/src/Google.Protobuf.Benchmarks/ParseMessagesBenchmark.cs
@@ -37,6 +37,7 @@ using System.IO;
 using System.Linq;
 using System.Buffers;
 using Google.Protobuf.WellKnownTypes;
+using Benchmarks.Proto3;
 
 namespace Google.Protobuf.Benchmarks
 {
@@ -50,6 +51,7 @@ namespace Google.Protobuf.Benchmarks
 
         SubTest manyWrapperFieldsTest = new SubTest(CreateManyWrapperFieldsMessage(), ManyWrapperFieldsMessage.Parser, () => new ManyWrapperFieldsMessage(), MaxMessages);
         SubTest manyPrimitiveFieldsTest = new SubTest(CreateManyPrimitiveFieldsMessage(), ManyPrimitiveFieldsMessage.Parser, () => new ManyPrimitiveFieldsMessage(), MaxMessages);
+        SubTest repeatedFieldTest = new SubTest(CreateRepeatedFieldMessage(), GoogleMessage1.Parser, () => new GoogleMessage1(), MaxMessages);
         SubTest emptyMessageTest = new SubTest(new Empty(), Empty.Parser, () => new Empty(), MaxMessages);
 
         public IEnumerable<int> MessageCountValues => new[] { 10, 100 };
@@ -83,6 +85,18 @@ namespace Google.Protobuf.Benchmarks
             return manyPrimitiveFieldsTest.ParseFromReadOnlySequence();
         }
 
+        [Benchmark]
+        public IMessage RepeatedFieldMessage_ParseFromByteArray()
+        {
+            return repeatedFieldTest.ParseFromByteArray();
+        }
+
+        [Benchmark]
+        public IMessage RepeatedFieldMessage_ParseFromReadOnlySequence()
+        {
+            return repeatedFieldTest.ParseFromReadOnlySequence();
+        }
+
         [Benchmark]
         public IMessage EmptyMessage_ParseFromByteArray()
         {
@@ -123,6 +137,20 @@ namespace Google.Protobuf.Benchmarks
             manyPrimitiveFieldsTest.ParseDelimitedMessagesFromReadOnlySequence(messageCount);
         }
 
+        [Benchmark]
+        [ArgumentsSource(nameof(MessageCountValues))]
+        public void RepeatedFieldMessage_ParseDelimitedMessagesFromByteArray(int messageCount)
+        {
+            repeatedFieldTest.ParseDelimitedMessagesFromByteArray(messageCount);
+        }
+
+        [Benchmark]
+        [ArgumentsSource(nameof(MessageCountValues))]
+        public void RepeatedFieldMessage_ParseDelimitedMessagesFromReadOnlySequence(int messageCount)
+        {
+            repeatedFieldTest.ParseDelimitedMessagesFromReadOnlySequence(messageCount);
+        }
+
         private static ManyWrapperFieldsMessage CreateManyWrapperFieldsMessage()
         {
             // Example data match data of an internal benchmarks
@@ -157,6 +185,17 @@ namespace Google.Protobuf.Benchmarks
             };
         }
 
+        private static GoogleMessage1 CreateRepeatedFieldMessage()
+        {
+            // Message with a repeated fixed length item collection
+            var message = new GoogleMessage1();
+            for (ulong i = 0; i < 1000; i++)
+            {
+                message.Field5.Add(i);
+            }
+            return message;
+        }
+
         private class SubTest
         {
             private readonly IMessage message;
diff --git a/csharp/src/Google.Protobuf.Test/Collections/RepeatedFieldTest.cs b/csharp/src/Google.Protobuf.Test/Collections/RepeatedFieldTest.cs
index 527c7d404c..61453e5ab2 100644
--- a/csharp/src/Google.Protobuf.Test/Collections/RepeatedFieldTest.cs
+++ b/csharp/src/Google.Protobuf.Test/Collections/RepeatedFieldTest.cs
@@ -595,6 +595,95 @@ namespace Google.Protobuf.Collections
             Assert.AreEqual(((SampleEnum)(-5)), values[5]);
         }
 
+        [Test]
+        public void TestPackedRepeatedFieldCollectionNonDivisibleLength()
+        {
+            uint tag = WireFormat.MakeTag(10, WireFormat.WireType.LengthDelimited);
+            var codec = FieldCodec.ForFixed32(tag);
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            output.WriteTag(tag);
+            output.WriteString("A long string");
+            output.WriteTag(codec.Tag);
+            output.WriteRawVarint32((uint)codec.FixedSize - 1); // Length not divisible by FixedSize
+            output.WriteFixed32(uint.MaxValue);
+            output.Flush();
+            stream.Position = 0;
+
+            var input = new CodedInputStream(stream);
+            input.ReadTag();
+            input.ReadString();
+            input.ReadTag();
+            var field = new RepeatedField<uint>();
+            Assert.Throws<InvalidProtocolBufferException>(() => field.AddEntriesFrom(input, codec));
+
+            // Collection was not pre-initialized
+            Assert.AreEqual(0, field.Count);
+        }
+
+        [Test]
+        public void TestPackedRepeatedFieldCollectionNotAllocatedWhenLengthExceedsBuffer()
+        {
+            uint tag = WireFormat.MakeTag(10, WireFormat.WireType.LengthDelimited);
+            var codec = FieldCodec.ForFixed32(tag);
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            output.WriteTag(tag);
+            output.WriteString("A long string");
+            output.WriteTag(codec.Tag);
+            output.WriteRawVarint32((uint)codec.FixedSize);
+            // Note that there is no content for the packed field.
+            // The field length exceeds the remaining length of content.
+            output.Flush();
+            stream.Position = 0;
+
+            var input = new CodedInputStream(stream);
+            input.ReadTag();
+            input.ReadString();
+            input.ReadTag();
+            var field = new RepeatedField<uint>();
+            Assert.Throws<InvalidProtocolBufferException>(() => field.AddEntriesFrom(input, codec));
+
+            // Collection was not pre-initialized
+            Assert.AreEqual(0, field.Count);
+        }
+
+        [Test]
+        public void TestPackedRepeatedFieldCollectionNotAllocatedWhenLengthExceedsRemainingData()
+        {
+            uint tag = WireFormat.MakeTag(10, WireFormat.WireType.LengthDelimited);
+            var codec = FieldCodec.ForFixed32(tag);
+            var stream = new MemoryStream();
+            var output = new CodedOutputStream(stream);
+            output.WriteTag(tag);
+            output.WriteString("A long string");
+            output.WriteTag(codec.Tag);
+            output.WriteRawVarint32((uint)codec.FixedSize);
+            // Note that there is no content for the packed field.
+            // The field length exceeds the remaining length of the buffer.
+            output.Flush();
+            stream.Position = 0;
+
+            var sequence = ReadOnlySequenceFactory.CreateWithContent(stream.ToArray());
+            ParseContext.Initialize(sequence, out ParseContext ctx);
+
+            ctx.ReadTag();
+            ctx.ReadString();
+            ctx.ReadTag();
+            var field = new RepeatedField<uint>();
+            try
+            {
+                field.AddEntriesFrom(ref ctx, codec);
+                Assert.Fail();
+            }
+            catch (InvalidProtocolBufferException)
+            {
+            }
+
+            // Collection was not pre-initialized
+            Assert.AreEqual(0, field.Count);
+        }
+
         // Fairly perfunctory tests for the non-generic IList implementation
         [Test]
         public void IList_Indexer()
diff --git a/csharp/src/Google.Protobuf.Test/ReadOnlySequenceFactory.cs b/csharp/src/Google.Protobuf.Test/ReadOnlySequenceFactory.cs
index d2050d37ca..588b559e9f 100644
--- a/csharp/src/Google.Protobuf.Test/ReadOnlySequenceFactory.cs
+++ b/csharp/src/Google.Protobuf.Test/ReadOnlySequenceFactory.cs
@@ -50,9 +50,9 @@ namespace Google.Protobuf
             while (currentIndex < data.Length)
             {
                 var segment = new List<byte>();
-                for (; currentIndex < Math.Min(currentIndex + segmentSize, data.Length); currentIndex++)
+                while (segment.Count < segmentSize && currentIndex < data.Length)
                 {
-                    segment.Add(data[currentIndex]);
+                    segment.Add(data[currentIndex++]);
                 }
                 segments.Add(segment.ToArray());
                 segments.Add(new byte[0]);
diff --git a/csharp/src/Google.Protobuf/Collections/RepeatedField.cs b/csharp/src/Google.Protobuf/Collections/RepeatedField.cs
index 5c39aabafb..b1bd9b13d1 100644
--- a/csharp/src/Google.Protobuf/Collections/RepeatedField.cs
+++ b/csharp/src/Google.Protobuf/Collections/RepeatedField.cs
@@ -35,6 +35,7 @@ using System.Collections;
 using System.Collections.Generic;
 using System.IO;
 using System.Security;
+using System.Threading;
 
 namespace Google.Protobuf.Collections
 {
@@ -126,9 +127,31 @@ namespace Google.Protobuf.Collections
                 if (length > 0)
                 {
                     int oldLimit = SegmentedBufferHelper.PushLimit(ref ctx.state, length);
-                    while (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state))
+
+                    // If the content is fixed size then we can calculate the length
+                    // of the repeated field and pre-initialize the underlying collection.
+                    //
+                    // Check that the supplied length doesn't exceed the underlying buffer.
+                    // That prevents a malicious length from initializing a very large collection.
+                    if (codec.FixedSize > 0 && length % codec.FixedSize == 0 && IsDataAvailable(ref ctx, length))
+                    {
+                        EnsureSize(count + (length / codec.FixedSize));
+
+                        while (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state))
+                        {
+                            // Only FieldCodecs with a fixed size can reach here, and they are all known
+                            // types that don't allow the user to specify a custom reader action.
+                            // reader action will never return null.
+                            array[count++] = reader(ref ctx);
+                        }
+                    }
+                    else
                     {
-                        Add(reader(ref ctx));
+                        // Content is variable size so add until we reach the limit.
+                        while (!SegmentedBufferHelper.IsReachedLimit(ref ctx.state))
+                        {
+                            Add(reader(ref ctx));
+                        }
                     }
                     SegmentedBufferHelper.PopLimit(ref ctx.state, oldLimit);
                 }
@@ -144,6 +167,24 @@ namespace Google.Protobuf.Collections
             }
         }
 
+        private bool IsDataAvailable(ref ParseContext ctx, int size)
+        {
+            // Data fits in remaining buffer
+            if (size <= ctx.state.bufferSize - ctx.state.bufferPos)
+            {
+                return true;
+            }
+
+            // Data fits in remaining source data.
+            // Note that this will never be true when reading from a stream as the total length is unknown.
+            if (size < ctx.state.segmentedBufferHelper.TotalLength - ctx.state.totalBytesRetired - ctx.state.bufferPos)
+            {
+                return true;
+            }
+
+            return false;
+        }
+
         /// <summary>
         /// Calculates the size of this collection based on the given codec.
         /// </summary>