From f20be839276cfc1129c12d89924164624ef3796d Mon Sep 17 00:00:00 2001 From: Jan Tattermusch Date: Fri, 24 Jan 2020 16:00:02 +0100 Subject: [PATCH] enforce recursion depth checking for unknown fields --- .../CodedInputStreamTest.cs | 63 ++++++++++++++++++- .../src/Google.Protobuf/CodedInputStream.cs | 31 ++++++++- csharp/src/Google.Protobuf/UnknownFieldSet.cs | 22 +++++-- 3 files changed, 108 insertions(+), 8 deletions(-) diff --git a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs index ba65b328e8..5f360ff46e 100644 --- a/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs +++ b/csharp/src/Google.Protobuf.Test/CodedInputStreamTest.cs @@ -33,6 +33,7 @@ using System; using System.IO; using Google.Protobuf.TestProtos; +using Proto2 = Google.Protobuf.TestProtos.Proto2; using NUnit.Framework; namespace Google.Protobuf @@ -337,6 +338,66 @@ namespace Google.Protobuf CodedInputStream input = CodedInputStream.CreateWithLimits(new MemoryStream(atRecursiveLimit.ToByteArray()), 1000000, CodedInputStream.DefaultRecursionLimit - 1); Assert.Throws(() => TestRecursiveMessage.Parser.ParseFrom(input)); } + + private static byte[] MakeMaliciousRecursionUnknownFieldsPayload(int recursionDepth) + { + // generate recursively nested groups that will be parsed as unknown fields + int unknownFieldNumber = 14; // an unused field number + MemoryStream ms = new MemoryStream(); + CodedOutputStream output = new CodedOutputStream(ms); + for (int i = 0; i < recursionDepth; i++) + { + output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.StartGroup)); + } + for (int i = 0; i < recursionDepth; i++) + { + output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.EndGroup)); + } + output.Flush(); + return ms.ToArray(); + } + + [Test] + public void MaliciousRecursion_UnknownFields() + { + byte[] payloadAtRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit); + byte[] payloadBeyondRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit + 1); + + Assert.DoesNotThrow(() => TestRecursiveMessage.Parser.ParseFrom(payloadAtRecursiveLimit)); + Assert.Throws(() => TestRecursiveMessage.Parser.ParseFrom(payloadBeyondRecursiveLimit)); + } + + [Test] + public void ReadGroup_WrongEndGroupTag() + { + int groupFieldNumber = Proto2.TestAllTypes.OptionalGroupFieldNumber; + + // write Proto2.TestAllTypes with "optional_group" set, but use wrong EndGroup closing tag + MemoryStream ms = new MemoryStream(); + CodedOutputStream output = new CodedOutputStream(ms); + output.WriteTag(WireFormat.MakeTag(groupFieldNumber, WireFormat.WireType.StartGroup)); + output.WriteGroup(new Proto2.TestAllTypes.Types.OptionalGroup { A = 12345 }); + // end group with different field number + output.WriteTag(WireFormat.MakeTag(groupFieldNumber + 1, WireFormat.WireType.EndGroup)); + output.Flush(); + var payload = ms.ToArray(); + + Assert.Throws(() => Proto2.TestAllTypes.Parser.ParseFrom(payload)); + } + + [Test] + public void ReadGroup_UnknownFields_WrongEndGroupTag() + { + MemoryStream ms = new MemoryStream(); + CodedOutputStream output = new CodedOutputStream(ms); + output.WriteTag(WireFormat.MakeTag(14, WireFormat.WireType.StartGroup)); + // end group with different field number + output.WriteTag(WireFormat.MakeTag(15, WireFormat.WireType.EndGroup)); + output.Flush(); + var payload = ms.ToArray(); + + Assert.Throws(() => TestRecursiveMessage.Parser.ParseFrom(payload)); + } [Test] public void SizeLimit() @@ -735,4 +796,4 @@ namespace Google.Protobuf } } } -} \ No newline at end of file +} diff --git a/csharp/src/Google.Protobuf/CodedInputStream.cs b/csharp/src/Google.Protobuf/CodedInputStream.cs index bea6bff34f..b9feda53cb 100644 --- a/csharp/src/Google.Protobuf/CodedInputStream.cs +++ b/csharp/src/Google.Protobuf/CodedInputStream.cs @@ -307,10 +307,17 @@ namespace Google.Protobuf throw InvalidProtocolBufferException.MoreDataAvailable(); } } - #endregion + internal void CheckLastTagWas(uint expectedTag) + { + if (lastTag != expectedTag) { + throw InvalidProtocolBufferException.InvalidEndTag(); + } + } + #endregion + #region Reading of tags etc - + /// /// Peeks at the next field tag. This is like calling , but the /// tag is not consumed. (So a subsequent call to will return the @@ -636,7 +643,27 @@ namespace Google.Protobuf throw InvalidProtocolBufferException.RecursionLimitExceeded(); } ++recursionDepth; + + uint tag = lastTag; + int fieldNumber = WireFormat.GetTagFieldNumber(tag); + builder.MergeFrom(this); + CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup)); + --recursionDepth; + } + + /// + /// Reads an embedded group unknown field from the stream. + /// + internal void ReadGroup(int fieldNumber, UnknownFieldSet set) + { + if (recursionDepth >= recursionLimit) + { + throw InvalidProtocolBufferException.RecursionLimitExceeded(); + } + ++recursionDepth; + set.MergeGroupFrom(this); + CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup)); --recursionDepth; } diff --git a/csharp/src/Google.Protobuf/UnknownFieldSet.cs b/csharp/src/Google.Protobuf/UnknownFieldSet.cs index d136cf1e65..7a2b6a00d2 100644 --- a/csharp/src/Google.Protobuf/UnknownFieldSet.cs +++ b/csharp/src/Google.Protobuf/UnknownFieldSet.cs @@ -215,12 +215,8 @@ namespace Google.Protobuf } case WireFormat.WireType.StartGroup: { - uint endTag = WireFormat.MakeTag(number, WireFormat.WireType.EndGroup); UnknownFieldSet set = new UnknownFieldSet(); - while (input.ReadTag() != endTag) - { - set.MergeFieldFrom(input); - } + input.ReadGroup(number, set); GetOrAddField(number).AddGroup(set); return true; } @@ -233,6 +229,22 @@ namespace Google.Protobuf } } + internal void MergeGroupFrom(CodedInputStream input) + { + while (true) + { + uint tag = input.ReadTag(); + if (tag == 0) + { + break; + } + if (!MergeFieldFrom(input)) + { + break; + } + } + } + /// /// Create a new UnknownFieldSet if unknownFields is null. /// Parse a single field from and merge it