diff --git a/conformance/binary_json_conformance_suite.cc b/conformance/binary_json_conformance_suite.cc index aeb26a828e..f62c705f62 100644 --- a/conformance/binary_json_conformance_suite.cc +++ b/conformance/binary_json_conformance_suite.cc @@ -1865,6 +1865,16 @@ void BinaryAndJsonConformanceSuite::RunJsonTestsForFieldNameConvention() { })", [](const Json::Value& value) { return !value.isMember("FieldName13"); }, true); + RunValidJsonTestWithValidator( + "FieldNameExtension", RECOMMENDED, + R"({ + "[protobuf_test_messages.proto2.extension_int32]": 1 + })", + [](const Json::Value& value) { + return value.isMember( + "[protobuf_test_messages.proto2.extension_int32]"); + }, + false); } void BinaryAndJsonConformanceSuite::RunJsonTestsForNonRepeatedTypes() { diff --git a/conformance/failure_list_cpp.txt b/conformance/failure_list_cpp.txt index 0c01e1e40c..d55fa9ff34 100644 --- a/conformance/failure_list_cpp.txt +++ b/conformance/failure_list_cpp.txt @@ -34,3 +34,4 @@ Recommended.Proto3.JsonInput.TrailingCommaInAnObject Recommended.Proto3.JsonInput.TrailingCommaInAnObjectWithNewlines Recommended.Proto3.JsonInput.TrailingCommaInAnObjectWithSpace Recommended.Proto3.JsonInput.TrailingCommaInAnObjectWithSpaceCommaSpace +Recommended.Proto2.JsonInput.FieldNameExtension.Validator diff --git a/conformance/failure_list_csharp.txt b/conformance/failure_list_csharp.txt index 2a20aa78e7..ed6ee97793 100644 --- a/conformance/failure_list_csharp.txt +++ b/conformance/failure_list_csharp.txt @@ -1,2 +1,3 @@ Recommended.Proto3.JsonInput.BytesFieldBase64Url.JsonOutput Recommended.Proto3.JsonInput.BytesFieldBase64Url.ProtobufOutput +Required.Proto2.JsonInput.StoresDefaultPrimitive.Validator diff --git a/conformance/failure_list_java.txt b/conformance/failure_list_java.txt index 394c3659c6..e84e0bac6d 100644 --- a/conformance/failure_list_java.txt +++ b/conformance/failure_list_java.txt @@ -34,6 +34,7 @@ Recommended.Proto3.JsonInput.StringFieldUnpairedHighSurrogate Recommended.Proto3.JsonInput.StringFieldUnpairedLowSurrogate Recommended.Proto3.JsonInput.Uint32MapFieldKeyNotQuoted Recommended.Proto3.JsonInput.Uint64MapFieldKeyNotQuoted +Recommended.Proto2.JsonInput.FieldNameExtension.Validator Required.Proto3.JsonInput.EnumFieldNotQuoted Required.Proto3.JsonInput.Int32FieldLeadingZero Required.Proto3.JsonInput.Int32FieldNegativeWithLeadingZero diff --git a/conformance/failure_list_php.txt b/conformance/failure_list_php.txt index 7ee5b9aae1..70c668a9e2 100644 --- a/conformance/failure_list_php.txt +++ b/conformance/failure_list_php.txt @@ -61,6 +61,7 @@ Recommended.Proto3.ProtobufInput.ValidDataRepeated.UINT64.PackedInput.DefaultOut Recommended.Proto3.ProtobufInput.ValidDataRepeated.UINT64.PackedInput.PackedOutput.ProtobufOutput Recommended.Proto3.ProtobufInput.ValidDataRepeated.UINT64.UnpackedInput.DefaultOutput.ProtobufOutput Recommended.Proto3.ProtobufInput.ValidDataRepeated.UINT64.UnpackedInput.PackedOutput.ProtobufOutput +Required.Proto2.JsonInput.StoresDefaultPrimitive.Validator Required.Proto3.JsonInput.DoubleFieldTooSmall Required.Proto3.JsonInput.FloatFieldTooLarge Required.Proto3.JsonInput.FloatFieldTooSmall @@ -75,8 +76,8 @@ Required.Proto3.JsonInput.Uint32FieldNotInteger Required.Proto3.JsonInput.Uint64FieldNotInteger Required.Proto3.ProtobufInput.RepeatedScalarMessageMerge.JsonOutput Required.Proto3.ProtobufInput.RepeatedScalarMessageMerge.ProtobufOutput +Required.Proto3.ProtobufInput.ValidDataOneof.MESSAGE.Merge.JsonOutput +Required.Proto3.ProtobufInput.ValidDataOneof.MESSAGE.Merge.ProtobufOutput Required.Proto3.ProtobufInput.ValidDataRepeated.FLOAT.PackedInput.JsonOutput Required.Proto3.ProtobufInput.ValidDataRepeated.FLOAT.UnpackedInput.JsonOutput Required.Proto3.ProtobufInput.ValidDataScalar.FLOAT[2].JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.MESSAGE.Merge.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.MESSAGE.Merge.ProtobufOutput diff --git a/conformance/failure_list_php_c.txt b/conformance/failure_list_php_c.txt index 6a6949fe40..d9e3e60c05 100644 --- a/conformance/failure_list_php_c.txt +++ b/conformance/failure_list_php_c.txt @@ -79,6 +79,7 @@ Recommended.Proto3.ProtobufInput.ValidDataScalarBinary.ENUM[3].ProtobufOutput Recommended.Proto3.ProtobufInput.ValidDataScalarBinary.ENUM[4].ProtobufOutput Required.DurationProtoInputTooLarge.JsonOutput Required.DurationProtoInputTooSmall.JsonOutput +Required.Proto2.JsonInput.StoresDefaultPrimitive.Validator Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.JsonOutput Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.ProtobufOutput Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.JsonOutput diff --git a/conformance/failure_list_php_c_32.txt b/conformance/failure_list_php_c_32.txt index f516f13ac3..b3e20e0070 100644 --- a/conformance/failure_list_php_c_32.txt +++ b/conformance/failure_list_php_c_32.txt @@ -84,6 +84,7 @@ Recommended.Proto3.ProtobufInput.ValidDataScalarBinary.ENUM[3].ProtobufOutput Recommended.Proto3.ProtobufInput.ValidDataScalarBinary.ENUM[4].ProtobufOutput Required.DurationProtoInputTooLarge.JsonOutput Required.DurationProtoInputTooSmall.JsonOutput +Required.Proto2.JsonInput.StoresDefaultPrimitive.Validator Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.JsonOutput Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.ProtobufOutput Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.JsonOutput @@ -99,13 +100,13 @@ Required.Proto3.JsonInput.FloatFieldInfinity.JsonOutput Required.Proto3.JsonInput.FloatFieldNan.JsonOutput Required.Proto3.JsonInput.FloatFieldNegativeInfinity.JsonOutput Required.Proto3.JsonInput.Int64FieldMaxValue.JsonOutput -Required.Proto3.JsonInput.Int64FieldMaxValue.ProtobufOutput Required.Proto3.JsonInput.Int64FieldMaxValueNotQuoted.JsonOutput Required.Proto3.JsonInput.Int64FieldMaxValueNotQuoted.ProtobufOutput +Required.Proto3.JsonInput.Int64FieldMaxValue.ProtobufOutput Required.Proto3.JsonInput.Int64FieldMinValue.JsonOutput -Required.Proto3.JsonInput.Int64FieldMinValue.ProtobufOutput Required.Proto3.JsonInput.Int64FieldMinValueNotQuoted.JsonOutput Required.Proto3.JsonInput.Int64FieldMinValueNotQuoted.ProtobufOutput +Required.Proto3.JsonInput.Int64FieldMinValue.ProtobufOutput Required.Proto3.JsonInput.OneofFieldDuplicate Required.Proto3.JsonInput.RejectTopLevelNull Required.Proto3.JsonInput.StringFieldSurrogatePair.JsonOutput @@ -123,9 +124,9 @@ Required.Proto3.JsonInput.TimestampWithNegativeOffset.ProtobufOutput Required.Proto3.JsonInput.TimestampWithPositiveOffset.JsonOutput Required.Proto3.JsonInput.TimestampWithPositiveOffset.ProtobufOutput Required.Proto3.JsonInput.Uint64FieldMaxValue.JsonOutput -Required.Proto3.JsonInput.Uint64FieldMaxValue.ProtobufOutput Required.Proto3.JsonInput.Uint64FieldMaxValueNotQuoted.JsonOutput Required.Proto3.JsonInput.Uint64FieldMaxValueNotQuoted.ProtobufOutput +Required.Proto3.JsonInput.Uint64FieldMaxValue.ProtobufOutput Required.Proto3.ProtobufInput.DoubleFieldNormalizeQuietNan.JsonOutput Required.Proto3.ProtobufInput.DoubleFieldNormalizeSignalingNan.JsonOutput Required.Proto3.ProtobufInput.FloatFieldNormalizeQuietNan.JsonOutput diff --git a/conformance/failure_list_ruby.txt b/conformance/failure_list_ruby.txt index f9533ae8ff..6b094371db 100644 --- a/conformance/failure_list_ruby.txt +++ b/conformance/failure_list_ruby.txt @@ -77,6 +77,7 @@ Recommended.Proto3.ProtobufInput.ValidDataScalarBinary.ENUM[3].ProtobufOutput Recommended.Proto3.ProtobufInput.ValidDataScalarBinary.ENUM[4].ProtobufOutput Required.DurationProtoInputTooLarge.JsonOutput Required.DurationProtoInputTooSmall.JsonOutput +Required.Proto2.JsonInput.StoresDefaultPrimitive.Validator Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.JsonOutput Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.ProtobufOutput Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.JsonOutput diff --git a/csharp/src/Google.Protobuf.Test/Reflection/FieldAccessTest.cs b/csharp/src/Google.Protobuf.Test/Reflection/FieldAccessTest.cs index 0d4034c5b1..b4dcdabdc7 100644 --- a/csharp/src/Google.Protobuf.Test/Reflection/FieldAccessTest.cs +++ b/csharp/src/Google.Protobuf.Test/Reflection/FieldAccessTest.cs @@ -98,7 +98,48 @@ namespace Google.Protobuf.Reflection } [Test] - public void HasValue_Proto3() + public void HasValue_Proto3_Message() + { + var message = new TestAllTypes(); + var accessor = ((IMessage) message).Descriptor.Fields[TestProtos.TestAllTypes.SingleForeignMessageFieldNumber].Accessor; + Assert.False(accessor.HasValue(message)); + message.SingleForeignMessage = new ForeignMessage(); + Assert.True(accessor.HasValue(message)); + message.SingleForeignMessage = null; + Assert.False(accessor.HasValue(message)); + } + + [Test] + public void HasValue_Proto3_Oneof() + { + TestAllTypes message = new TestAllTypes(); + var accessor = ((IMessage) message).Descriptor.Fields[TestProtos.TestAllTypes.OneofStringFieldNumber].Accessor; + Assert.False(accessor.HasValue(message)); + // Even though it's the default value, we still have a value. + message.OneofString = ""; + Assert.True(accessor.HasValue(message)); + message.OneofString = "hello"; + Assert.True(accessor.HasValue(message)); + message.OneofUint32 = 10; + Assert.False(accessor.HasValue(message)); + } + + [Test] + public void HasValue_Proto3_Primitive_Optional() + { + var message = new TestProto3Optional(); + var accessor = ((IMessage) message).Descriptor.Fields[TestProto3Optional.OptionalInt64FieldNumber].Accessor; + Assert.IsFalse(accessor.HasValue(message)); + message.OptionalInt64 = 5L; + Assert.IsTrue(accessor.HasValue(message)); + message.ClearOptionalInt64(); + Assert.IsFalse(accessor.HasValue(message)); + message.OptionalInt64 = 0L; + Assert.IsTrue(accessor.HasValue(message)); + } + + [Test] + public void HasValue_Proto3_Primitive_NotOptional() { IMessage message = SampleMessages.CreateFullTestAllTypes(); var fields = message.Descriptor.Fields; @@ -106,36 +147,63 @@ namespace Google.Protobuf.Reflection } [Test] - public void HasValue_Proto3Optional() + public void HasValue_Proto3_Repeated() { - IMessage message = new TestProto3Optional - { - OptionalInt32 = 0, - LazyNestedMessage = new TestProto3Optional.Types.NestedMessage() - }; - var fields = message.Descriptor.Fields; - Assert.IsFalse(fields[TestProto3Optional.OptionalInt64FieldNumber].Accessor.HasValue(message)); - Assert.IsFalse(fields[TestProto3Optional.OptionalNestedMessageFieldNumber].Accessor.HasValue(message)); - Assert.IsTrue(fields[TestProto3Optional.LazyNestedMessageFieldNumber].Accessor.HasValue(message)); - Assert.IsTrue(fields[TestProto3Optional.OptionalInt32FieldNumber].Accessor.HasValue(message)); + var message = new TestAllTypes(); + var accessor = ((IMessage) message).Descriptor.Fields[TestProtos.TestAllTypes.RepeatedBoolFieldNumber].Accessor; + Assert.Throws(() => accessor.HasValue(message)); } [Test] - public void HasValue() + public void HasValue_Proto2_Primitive() { - IMessage message = new Proto2.TestAllTypes(); - var fields = message.Descriptor.Fields; - var accessor = fields[Proto2.TestAllTypes.OptionalBoolFieldNumber].Accessor; + var message = new Proto2.TestAllTypes(); + var accessor = ((IMessage) message).Descriptor.Fields[Proto2.TestAllTypes.OptionalInt64FieldNumber].Accessor; + + Assert.IsFalse(accessor.HasValue(message)); + message.OptionalInt64 = 5L; + Assert.IsTrue(accessor.HasValue(message)); + message.ClearOptionalInt64(); + Assert.IsFalse(accessor.HasValue(message)); + message.OptionalInt64 = 0L; + Assert.IsTrue(accessor.HasValue(message)); + } - Assert.False(accessor.HasValue(message)); + [Test] + public void HasValue_Proto2_Message() + { + var message = new Proto2.TestAllTypes(); + var field = ((IMessage) message).Descriptor.Fields[Proto2.TestAllTypes.OptionalForeignMessageFieldNumber]; + Assert.False(field.Accessor.HasValue(message)); + message.OptionalForeignMessage = new Proto2.ForeignMessage(); + Assert.True(field.Accessor.HasValue(message)); + message.OptionalForeignMessage = null; + Assert.False(field.Accessor.HasValue(message)); + } - accessor.SetValue(message, true); + [Test] + public void HasValue_Proto2_Oneof() + { + var message = new Proto2.TestAllTypes(); + var accessor = ((IMessage) message).Descriptor.Fields[Proto2.TestAllTypes.OneofStringFieldNumber].Accessor; + Assert.False(accessor.HasValue(message)); + // Even though it's the default value, we still have a value. + message.OneofString = ""; Assert.True(accessor.HasValue(message)); - - accessor.Clear(message); + message.OneofString = "hello"; + Assert.True(accessor.HasValue(message)); + message.OneofUint32 = 10; Assert.False(accessor.HasValue(message)); } + [Test] + public void HasValue_Proto2_Repeated() + { + var message = new Proto2.TestAllTypes(); + var accessor = ((IMessage) message).Descriptor.Fields[Proto2.TestAllTypes.RepeatedBoolFieldNumber].Accessor; + Assert.Throws(() => accessor.HasValue(message)); + } + [Test] public void SetValue_SingleFields() { @@ -262,6 +330,42 @@ namespace Google.Protobuf.Reflection Assert.Null(message.OptionalNestedMessage); } + [Test] + public void Clear_Proto3_Oneof() + { + var message = new TestAllTypes(); + var accessor = ((IMessage) message).Descriptor.Fields[TestProtos.TestAllTypes.OneofUint32FieldNumber].Accessor; + + // The field accessor Clear method only affects a oneof if the current case is the one being cleared. + message.OneofString = "hello"; + Assert.AreEqual(TestProtos.TestAllTypes.OneofFieldOneofCase.OneofString, message.OneofFieldCase); + accessor.Clear(message); + Assert.AreEqual(TestProtos.TestAllTypes.OneofFieldOneofCase.OneofString, message.OneofFieldCase); + + message.OneofUint32 = 100; + Assert.AreEqual(TestProtos.TestAllTypes.OneofFieldOneofCase.OneofUint32, message.OneofFieldCase); + accessor.Clear(message); + Assert.AreEqual(TestProtos.TestAllTypes.OneofFieldOneofCase.None, message.OneofFieldCase); + } + + [Test] + public void Clear_Proto2_Oneof() + { + var message = new Proto2.TestAllTypes(); + var accessor = ((IMessage) message).Descriptor.Fields[Proto2.TestAllTypes.OneofUint32FieldNumber].Accessor; + + // The field accessor Clear method only affects a oneof if the current case is the one being cleared. + message.OneofString = "hello"; + Assert.AreEqual(Proto2.TestAllTypes.OneofFieldOneofCase.OneofString, message.OneofFieldCase); + accessor.Clear(message); + Assert.AreEqual(Proto2.TestAllTypes.OneofFieldOneofCase.OneofString, message.OneofFieldCase); + + message.OneofUint32 = 100; + Assert.AreEqual(Proto2.TestAllTypes.OneofFieldOneofCase.OneofUint32, message.OneofFieldCase); + accessor.Clear(message); + Assert.AreEqual(Proto2.TestAllTypes.OneofFieldOneofCase.None, message.OneofFieldCase); + } + [Test] public void FieldDescriptor_ByName() { @@ -301,5 +405,32 @@ namespace Google.Protobuf.Reflection message.ClearExtension(RepeatedBoolExtension); Assert.IsNull(message.GetExtension(RepeatedBoolExtension)); } + + [Test] + public void HasPresence() + { + // Proto3 + var fields = TestProtos.TestAllTypes.Descriptor.Fields; + Assert.IsFalse(fields[TestProtos.TestAllTypes.SingleBoolFieldNumber].HasPresence); + Assert.IsTrue(fields[TestProtos.TestAllTypes.OneofBytesFieldNumber].HasPresence); + Assert.IsTrue(fields[TestProtos.TestAllTypes.SingleForeignMessageFieldNumber].HasPresence); + Assert.IsFalse(fields[TestProtos.TestAllTypes.RepeatedBoolFieldNumber].HasPresence); + + fields = TestMap.Descriptor.Fields; + Assert.IsFalse(fields[TestMap.MapBoolBoolFieldNumber].HasPresence); + + fields = TestProto3Optional.Descriptor.Fields; + Assert.IsTrue(fields[TestProto3Optional.OptionalBoolFieldNumber].HasPresence); + + // Proto2 + fields = Proto2.TestAllTypes.Descriptor.Fields; + Assert.IsTrue(fields[Proto2.TestAllTypes.OptionalBoolFieldNumber].HasPresence); + Assert.IsTrue(fields[Proto2.TestAllTypes.OneofBytesFieldNumber].HasPresence); + Assert.IsTrue(fields[Proto2.TestAllTypes.OptionalForeignMessageFieldNumber].HasPresence); + Assert.IsFalse(fields[Proto2.TestAllTypes.RepeatedBoolFieldNumber].HasPresence); + + fields = Proto2.TestRequired.Descriptor.Fields; + Assert.IsTrue(fields[Proto2.TestRequired.AFieldNumber].HasPresence); + } } } diff --git a/csharp/src/Google.Protobuf/Extension.cs b/csharp/src/Google.Protobuf/Extension.cs index a96f8d29b6..6dd1ceaa8e 100644 --- a/csharp/src/Google.Protobuf/Extension.cs +++ b/csharp/src/Google.Protobuf/Extension.cs @@ -55,6 +55,8 @@ namespace Google.Protobuf /// Gets the field number of this extension /// public int FieldNumber { get; } + + internal abstract bool IsRepeated { get; } } /// @@ -79,6 +81,8 @@ namespace Google.Protobuf internal override Type TargetType => typeof(TTarget); + internal override bool IsRepeated => false; + internal override IExtensionValue CreateValue() { return new ExtensionValue(codec); @@ -105,6 +109,8 @@ namespace Google.Protobuf internal override Type TargetType => typeof(TTarget); + internal override bool IsRepeated => true; + internal override IExtensionValue CreateValue() { return new RepeatedExtensionValue(codec); diff --git a/csharp/src/Google.Protobuf/Reflection/FieldDescriptor.cs b/csharp/src/Google.Protobuf/Reflection/FieldDescriptor.cs index 4d87bd7561..3efa0929bb 100644 --- a/csharp/src/Google.Protobuf/Reflection/FieldDescriptor.cs +++ b/csharp/src/Google.Protobuf/Reflection/FieldDescriptor.cs @@ -70,6 +70,21 @@ namespace Google.Protobuf.Reflection /// public string JsonName { get; } + /// + /// Indicates whether this field supports presence, either implicitly (e.g. due to it being a message + /// type field) or explicitly via Has/Clear members. If this returns true, it is safe to call + /// and + /// on this field's accessor with a suitable message. + /// + public bool HasPresence => + Extension != null ? !Extension.IsRepeated + : IsRepeated ? false + : IsMap ? false + : FieldType == FieldType.Message ? true + // This covers "real oneof members" and "proto3 optional fields" + : ContainingOneof != null ? true + : File.Syntax == Syntax.Proto2; + internal FieldDescriptorProto Proto { get; } /// diff --git a/csharp/src/Google.Protobuf/Reflection/SingleFieldAccessor.cs b/csharp/src/Google.Protobuf/Reflection/SingleFieldAccessor.cs index ed844bc51d..07d84d7fb9 100644 --- a/csharp/src/Google.Protobuf/Reflection/SingleFieldAccessor.cs +++ b/csharp/src/Google.Protobuf/Reflection/SingleFieldAccessor.cs @@ -57,63 +57,68 @@ namespace Google.Protobuf.Reflection throw new ArgumentException("Not all required properties/methods available"); } setValueDelegate = ReflectionUtil.CreateActionIMessageObject(property.GetSetMethod()); - if (descriptor.File.Syntax == Syntax.Proto3 && !descriptor.Proto.Proto3Optional) + + // Note: this looks worrying in that we access the containing oneof, which isn't valid until cross-linking + // is complete... but field accessors aren't created until after cross-linking. + // The oneof itself won't be cross-linked yet, but that's okay: the oneof accessor is created + // earlier. + + // Message fields always support presence, via null checks. + if (descriptor.FieldType == FieldType.Message) + { + hasDelegate = message => GetValue(message) != null; + clearDelegate = message => SetValue(message, null); + } + // Oneof fields always support presence, via case checks. + // Note that clearing the field is a no-op unless that specific field is the current "case". + else if (descriptor.RealContainingOneof != null) { - hasDelegate = message => + var oneofAccessor = descriptor.RealContainingOneof.Accessor; + hasDelegate = message => oneofAccessor.GetCaseFieldDescriptor(message) == descriptor; + clearDelegate = message => { - throw new InvalidOperationException("HasValue is not implemented for non-optional proto3 fields"); + // Clear on a field only affects the oneof itself if the current case is the field we're accessing. + if (oneofAccessor.GetCaseFieldDescriptor(message) == descriptor) + { + oneofAccessor.Clear(message); + } }; - var clrType = property.PropertyType; - - // TODO: Validate that this is a reasonable single field? (Should be a value type, a message type, or string/ByteString.) - object defaultValue = - descriptor.FieldType == FieldType.Message ? null - : clrType == typeof(string) ? "" - : clrType == typeof(ByteString) ? ByteString.Empty - : Activator.CreateInstance(clrType); - clearDelegate = message => SetValue(message, defaultValue); } - else + // Primitive fields always support presence in proto2, and support presence in proto3 for optional fields. + else if (descriptor.File.Syntax == Syntax.Proto2 || descriptor.Proto.Proto3Optional) { - // For message fields, just compare with null and set to null. - // For primitive fields, use the Has/Clear methods. - - if (descriptor.FieldType == FieldType.Message) + MethodInfo hasMethod = property.DeclaringType.GetRuntimeProperty("Has" + property.Name).GetMethod; + if (hasMethod == null) { - hasDelegate = message => GetValue(message) != null; - clearDelegate = message => SetValue(message, null); + throw new ArgumentException("Not all required properties/methods are available"); } - else + hasDelegate = ReflectionUtil.CreateFuncIMessageBool(hasMethod); + MethodInfo clearMethod = property.DeclaringType.GetRuntimeMethod("Clear" + property.Name, ReflectionUtil.EmptyTypes); + if (clearMethod == null) { - MethodInfo hasMethod = property.DeclaringType.GetRuntimeProperty("Has" + property.Name).GetMethod; - if (hasMethod == null) - { - throw new ArgumentException("Not all required properties/methods are available"); - } - hasDelegate = ReflectionUtil.CreateFuncIMessageBool(hasMethod); - MethodInfo clearMethod = property.DeclaringType.GetRuntimeMethod("Clear" + property.Name, ReflectionUtil.EmptyTypes); - if (clearMethod == null) - { - throw new ArgumentException("Not all required properties/methods are available"); - } - clearDelegate = ReflectionUtil.CreateActionIMessage(clearMethod); + throw new ArgumentException("Not all required properties/methods are available"); } + clearDelegate = ReflectionUtil.CreateActionIMessage(clearMethod); } - } + // What's left? + // Primitive proto3 fields without the optional keyword, which aren't in oneofs. + else + { + hasDelegate = message => { throw new InvalidOperationException("Presence is not implemented for this field"); }; - public override void Clear(IMessage message) - { - clearDelegate(message); - } + // While presence isn't supported, clearing still is; it's just setting to a default value. + var clrType = property.PropertyType; - public override bool HasValue(IMessage message) - { - return hasDelegate(message); + object defaultValue = + clrType == typeof(string) ? "" + : clrType == typeof(ByteString) ? ByteString.Empty + : Activator.CreateInstance(clrType); + clearDelegate = message => SetValue(message, defaultValue); + } } - public override void SetValue(IMessage message, object value) - { - setValueDelegate(message, value); - } + public override void Clear(IMessage message) => clearDelegate(message); + public override bool HasValue(IMessage message) => hasDelegate(message); + public override void SetValue(IMessage message, object value) => setValueDelegate(message, value); } } diff --git a/java/core/src/test/java/com/google/protobuf/FieldPresenceTest.java b/java/core/src/test/java/com/google/protobuf/FieldPresenceTest.java index 53ec6fe68c..a1c98c0cef 100644 --- a/java/core/src/test/java/com/google/protobuf/FieldPresenceTest.java +++ b/java/core/src/test/java/com/google/protobuf/FieldPresenceTest.java @@ -38,6 +38,7 @@ import com.google.protobuf.Descriptors.OneofDescriptor; import com.google.protobuf.FieldPresenceTestProto.TestAllTypes; import com.google.protobuf.FieldPresenceTestProto.TestOptionalFieldsOnly; import com.google.protobuf.FieldPresenceTestProto.TestRepeatedFieldsOnly; +import com.google.protobuf.testing.proto.TestProto3Optional; import protobuf_unittest.UnittestProto; import junit.framework.TestCase; @@ -101,6 +102,113 @@ public class FieldPresenceTest extends TestCase { UnittestProto.TestAllTypes.Builder.class, TestAllTypes.Builder.class, "OneofBytes"); } + public void testHasMethodForProto3Optional() throws Exception { + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalInt32()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalInt64()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalUint32()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalUint64()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalSint32()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalSint64()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalFixed32()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalFixed64()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalFloat()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalDouble()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalBool()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalString()); + assertFalse(TestProto3Optional.getDefaultInstance().hasOptionalBytes()); + + TestProto3Optional.Builder builder = TestProto3Optional.newBuilder().setOptionalInt32(0); + assertTrue(builder.hasOptionalInt32()); + assertTrue(builder.build().hasOptionalInt32()); + + TestProto3Optional.Builder otherBuilder = TestProto3Optional.newBuilder().setOptionalInt32(1); + otherBuilder.mergeFrom(builder.build()); + assertTrue(otherBuilder.hasOptionalInt32()); + assertEquals(0, otherBuilder.getOptionalInt32()); + + TestProto3Optional.Builder builder3 = + TestProto3Optional.newBuilder().setOptionalNestedEnumValue(5); + assertTrue(builder3.hasOptionalNestedEnum()); + + TestProto3Optional.Builder builder4 = + TestProto3Optional.newBuilder().setOptionalNestedEnum(TestProto3Optional.NestedEnum.FOO); + assertTrue(builder4.hasOptionalNestedEnum()); + + TestProto3Optional proto = TestProto3Optional.parseFrom(builder.build().toByteArray()); + assertTrue(proto.hasOptionalInt32()); + assertTrue(proto.toBuilder().hasOptionalInt32()); + } + + private static void assertProto3OptionalReflection(String name) throws Exception { + FieldDescriptor fieldDescriptor = TestProto3Optional.getDescriptor().findFieldByName(name); + OneofDescriptor oneofDescriptor = fieldDescriptor.getContainingOneof(); + assertNotNull(fieldDescriptor.getContainingOneof()); + assertTrue(fieldDescriptor.hasOptionalKeyword()); + assertTrue(fieldDescriptor.hasPresence()); + + assertFalse(TestProto3Optional.getDefaultInstance().hasOneof(oneofDescriptor)); + assertNull(TestProto3Optional.getDefaultInstance().getOneofFieldDescriptor(oneofDescriptor)); + + TestProto3Optional.Builder builder = TestProto3Optional.newBuilder(); + builder.setField(fieldDescriptor, fieldDescriptor.getDefaultValue()); + assertTrue(builder.hasField(fieldDescriptor)); + assertEquals(fieldDescriptor.getDefaultValue(), builder.getField(fieldDescriptor)); + assertTrue(builder.build().hasField(fieldDescriptor)); + assertEquals(fieldDescriptor.getDefaultValue(), builder.build().getField(fieldDescriptor)); + assertTrue(builder.hasOneof(oneofDescriptor)); + assertEquals(fieldDescriptor, builder.getOneofFieldDescriptor(oneofDescriptor)); + assertTrue(builder.build().hasOneof(oneofDescriptor)); + assertEquals(fieldDescriptor, builder.build().getOneofFieldDescriptor(oneofDescriptor)); + + TestProto3Optional.Builder otherBuilder = TestProto3Optional.newBuilder(); + otherBuilder.mergeFrom(builder.build()); + assertTrue(otherBuilder.hasField(fieldDescriptor)); + assertEquals(fieldDescriptor.getDefaultValue(), otherBuilder.getField(fieldDescriptor)); + + TestProto3Optional proto = TestProto3Optional.parseFrom(builder.build().toByteArray()); + assertTrue(proto.hasField(fieldDescriptor)); + assertTrue(proto.toBuilder().hasField(fieldDescriptor)); + + DynamicMessage.Builder dynamicBuilder = + DynamicMessage.newBuilder(TestProto3Optional.getDescriptor()); + dynamicBuilder.setField(fieldDescriptor, fieldDescriptor.getDefaultValue()); + assertTrue(dynamicBuilder.hasField(fieldDescriptor)); + assertEquals(fieldDescriptor.getDefaultValue(), dynamicBuilder.getField(fieldDescriptor)); + assertTrue(dynamicBuilder.build().hasField(fieldDescriptor)); + assertEquals( + fieldDescriptor.getDefaultValue(), dynamicBuilder.build().getField(fieldDescriptor)); + assertTrue(dynamicBuilder.hasOneof(oneofDescriptor)); + assertEquals(fieldDescriptor, dynamicBuilder.getOneofFieldDescriptor(oneofDescriptor)); + assertTrue(dynamicBuilder.build().hasOneof(oneofDescriptor)); + assertEquals(fieldDescriptor, dynamicBuilder.build().getOneofFieldDescriptor(oneofDescriptor)); + + DynamicMessage.Builder otherDynamicBuilder = + DynamicMessage.newBuilder(TestProto3Optional.getDescriptor()); + otherDynamicBuilder.mergeFrom(dynamicBuilder.build()); + assertTrue(otherDynamicBuilder.hasField(fieldDescriptor)); + assertEquals(fieldDescriptor.getDefaultValue(), otherDynamicBuilder.getField(fieldDescriptor)); + + DynamicMessage dynamicProto = + DynamicMessage.parseFrom(TestProto3Optional.getDescriptor(), builder.build().toByteArray()); + assertTrue(dynamicProto.hasField(fieldDescriptor)); + assertTrue(dynamicProto.toBuilder().hasField(fieldDescriptor)); + } + + public void testProto3Optional_reflection() throws Exception { + assertProto3OptionalReflection("optional_int32"); + assertProto3OptionalReflection("optional_int64"); + assertProto3OptionalReflection("optional_uint32"); + assertProto3OptionalReflection("optional_uint64"); + assertProto3OptionalReflection("optional_sint32"); + assertProto3OptionalReflection("optional_sint64"); + assertProto3OptionalReflection("optional_fixed32"); + assertProto3OptionalReflection("optional_fixed64"); + assertProto3OptionalReflection("optional_float"); + assertProto3OptionalReflection("optional_double"); + assertProto3OptionalReflection("optional_bool"); + assertProto3OptionalReflection("optional_string"); + assertProto3OptionalReflection("optional_bytes"); + } public void testOneofEquals() throws Exception { TestAllTypes.Builder builder = TestAllTypes.newBuilder(); diff --git a/java/core/src/test/java/com/google/protobuf/TextFormatTest.java b/java/core/src/test/java/com/google/protobuf/TextFormatTest.java index 50ba0168da..6ca3ae111f 100644 --- a/java/core/src/test/java/com/google/protobuf/TextFormatTest.java +++ b/java/core/src/test/java/com/google/protobuf/TextFormatTest.java @@ -40,6 +40,8 @@ import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.TextFormat.Parser.SingularOverwritePolicy; +import com.google.protobuf.testing.proto.TestProto3Optional; +import com.google.protobuf.testing.proto.TestProto3Optional.NestedEnum; import any_test.AnyTestProto.TestAny; import map_test.MapTestProto.TestMap; import protobuf_unittest.UnittestMset.TestMessageSetExtension1; @@ -319,6 +321,24 @@ public class TextFormatTest extends TestCase { assertEquals(canonicalExoticText, message.toString()); } + public void testRoundtripProto3Optional() throws Exception { + Message message = + TestProto3Optional.newBuilder() + .setOptionalInt32(1) + .setOptionalInt64(2) + .setOptionalNestedEnum(NestedEnum.BAZ) + .build(); + TestProto3Optional.Builder message2 = TestProto3Optional.newBuilder(); + TextFormat.merge(message.toString(), message2); + + assertTrue(message2.hasOptionalInt32()); + assertTrue(message2.hasOptionalInt64()); + assertTrue(message2.hasOptionalNestedEnum()); + assertEquals(1, message2.getOptionalInt32()); + assertEquals(2, message2.getOptionalInt64()); + assertEquals(NestedEnum.BAZ, message2.getOptionalNestedEnum()); + } + public void testPrintMessageSet() throws Exception { TestMessageSet messageSet = TestMessageSet.newBuilder() diff --git a/js/experimental/runtime/kernel/conformance/conformance_testee.js b/js/experimental/runtime/kernel/conformance/conformance_testee.js old mode 100644 new mode 100755 diff --git a/js/experimental/runtime/kernel/conformance/conformance_testee_runner_node.js b/js/experimental/runtime/kernel/conformance/conformance_testee_runner_node.js old mode 100644 new mode 100755 diff --git a/js/experimental/runtime/kernel/indexer.js b/js/experimental/runtime/kernel/indexer.js index d4dea13515..205a34e44d 100644 --- a/js/experimental/runtime/kernel/indexer.js +++ b/js/experimental/runtime/kernel/indexer.js @@ -8,7 +8,8 @@ const BinaryStorage = goog.require('protobuf.runtime.BinaryStorage'); const BufferDecoder = goog.require('protobuf.binary.BufferDecoder'); const WireType = goog.require('protobuf.binary.WireType'); const {Field} = goog.require('protobuf.binary.field'); -const {checkCriticalElementIndex, checkCriticalState} = goog.require('protobuf.internal.checks'); +const {checkCriticalState} = goog.require('protobuf.internal.checks'); +const {skipField, tagToFieldNumber, tagToWireType} = goog.require('protobuf.binary.tag'); /** * Appends a new entry in the index array for the given field number. @@ -26,26 +27,6 @@ function addIndexEntry(storage, fieldNumber, wireType, startIndex) { } } -/** - * Returns wire type stored in a tag. - * Protos store the wire type as the first 3 bit of a tag. - * @param {number} tag - * @return {!WireType} - */ -function tagToWireType(tag) { - return /** @type {!WireType} */ (tag & 0x07); -} - -/** - * Returns the field number stored in a tag. - * Protos store the field number in the upper 29 bits of a 32 bit number. - * @param {number} tag - * @return {number} - */ -function tagToFieldNumber(tag) { - return tag >>> 3; -} - /** * Creates an index of field locations in a given binary protobuf. * @param {!BufferDecoder} bufferDecoder @@ -62,83 +43,13 @@ function buildIndex(bufferDecoder, pivot) { const wireType = tagToWireType(tag); const fieldNumber = tagToFieldNumber(tag); checkCriticalState(fieldNumber > 0, `Invalid field number ${fieldNumber}`); - addIndexEntry(storage, fieldNumber, wireType, bufferDecoder.cursor()); - - checkCriticalState( - !skipField_(bufferDecoder, wireType, fieldNumber), - 'Found unmatched stop group.'); + skipField(bufferDecoder, wireType, fieldNumber); } return storage; } -/** - * Skips over fields until the next field of the message. - * @param {!BufferDecoder} bufferDecoder - * @param {!WireType} wireType - * @param {number} fieldNumber - * @return {boolean} Whether the field we skipped over was a stop group. - * @private - */ -function skipField_(bufferDecoder, wireType, fieldNumber) { - switch (wireType) { - case WireType.VARINT: - checkCriticalElementIndex( - bufferDecoder.cursor(), bufferDecoder.endIndex()); - bufferDecoder.skipVarint(); - return false; - case WireType.FIXED64: - bufferDecoder.skip(8); - return false; - case WireType.DELIMITED: - checkCriticalElementIndex( - bufferDecoder.cursor(), bufferDecoder.endIndex()); - const length = bufferDecoder.getUnsignedVarint32(); - bufferDecoder.skip(length); - return false; - case WireType.START_GROUP: - checkCriticalState( - skipGroup_(bufferDecoder, fieldNumber), 'No end group found.'); - return false; - case WireType.END_GROUP: - // Signal that we found a stop group to the caller - return true; - case WireType.FIXED32: - bufferDecoder.skip(4); - return false; - default: - throw new Error(`Invalid wire type: ${wireType}`); - } -} - -/** - * Skips over fields until it finds the end of a given group. - * @param {!BufferDecoder} bufferDecoder - * @param {number} groupFieldNumber - * @return {boolean} Returns true if an end was found. - * @private - */ -function skipGroup_(bufferDecoder, groupFieldNumber) { - // On a start group we need to keep skipping fields until we find a - // corresponding stop group - // Note: Since we are calling skipField from here nested groups will be - // handled by recursion of this method and thus we will not see a nested - // STOP GROUP here unless there is something wrong with the input data. - while (bufferDecoder.hasNext()) { - const tag = bufferDecoder.getUnsignedVarint32(); - const wireType = tagToWireType(tag); - const fieldNumber = tagToFieldNumber(tag); - - if (skipField_(bufferDecoder, wireType, fieldNumber)) { - checkCriticalState( - groupFieldNumber === fieldNumber, - `Expected stop group for fieldnumber ${groupFieldNumber} not found.`); - return true; - } - } - return false; -} - exports = { buildIndex, + tagToWireType, }; diff --git a/js/experimental/runtime/kernel/indexer_test.js b/js/experimental/runtime/kernel/indexer_test.js index 101e44283f..ffb8807994 100644 --- a/js/experimental/runtime/kernel/indexer_test.js +++ b/js/experimental/runtime/kernel/indexer_test.js @@ -72,12 +72,12 @@ describe('Indexer does', () => { it('fail for invalid wire type (6)', () => { expect(() => buildIndex(createBufferDecoder(0x0E, 0x01), PIVOT)) - .toThrowError('Invalid wire type: 6'); + .toThrowError('Unexpected wire type: 6'); }); it('fail for invalid wire type (7)', () => { expect(() => buildIndex(createBufferDecoder(0x0F, 0x01), PIVOT)) - .toThrowError('Invalid wire type: 7'); + .toThrowError('Unexpected wire type: 7'); }); it('index varint', () => { @@ -269,33 +269,8 @@ describe('Indexer does', () => { it('fail on unmatched stop group', () => { const data = createBufferDecoder(0x0C, 0x01); - if (CHECK_CRITICAL_STATE) { - expect(() => buildIndex(data, PIVOT)) - .toThrowError('Found unmatched stop group.'); - } else { - // Note in unchecked mode we produce invalid output for invalid inputs. - // This test just documents our behavior in those cases. - // These values might change at any point and are not considered - // what the implementation should be doing here. - const storage = buildIndex(data, PIVOT); - - expect(getStorageSize(storage)).toBe(1); - const entryArray = storage.get(1).getIndexArray(); - expect(entryArray).not.toBeUndefined(); - expect(entryArray.length).toBe(1); - const entry = entryArray[0]; - - expect(Field.getWireType(entry)).toBe(WireType.END_GROUP); - expect(Field.getStartIndex(entry)).toBe(1); - - const entryArray2 = storage.get(0).getIndexArray(); - expect(entryArray2).not.toBeUndefined(); - expect(entryArray2.length).toBe(1); - const entry2 = entryArray2[0]; - - expect(Field.getWireType(entry2)).toBe(WireType.FIXED64); - expect(Field.getStartIndex(entry2)).toBe(2); - } + expect(() => buildIndex(data, PIVOT)) + .toThrowError('Unexpected wire type: 4'); }); it('fail for groups without matching stop group', () => { diff --git a/js/experimental/runtime/kernel/kernel.js b/js/experimental/runtime/kernel/kernel.js index 2a2b2d8c48..bb2608398a 100644 --- a/js/experimental/runtime/kernel/kernel.js +++ b/js/experimental/runtime/kernel/kernel.js @@ -26,6 +26,7 @@ const reader = goog.require('protobuf.binary.reader'); const {CHECK_TYPE, checkCriticalElementIndex, checkCriticalState, checkCriticalType, checkCriticalTypeBool, checkCriticalTypeBoolArray, checkCriticalTypeByteString, checkCriticalTypeByteStringArray, checkCriticalTypeDouble, checkCriticalTypeDoubleArray, checkCriticalTypeFloat, checkCriticalTypeFloatIterable, checkCriticalTypeMessageArray, checkCriticalTypeSignedInt32, checkCriticalTypeSignedInt32Array, checkCriticalTypeSignedInt64, checkCriticalTypeSignedInt64Array, checkCriticalTypeString, checkCriticalTypeStringArray, checkCriticalTypeUnsignedInt32, checkCriticalTypeUnsignedInt32Array, checkDefAndNotNull, checkElementIndex, checkFieldNumber, checkFunctionExists, checkState, checkTypeDouble, checkTypeFloat, checkTypeSignedInt32, checkTypeSignedInt64, checkTypeUnsignedInt32} = goog.require('protobuf.internal.checks'); const {Field, IndexEntry} = goog.require('protobuf.binary.field'); const {buildIndex} = goog.require('protobuf.binary.indexer'); +const {createTag, get32BitVarintLength, getTagLength} = goog.require('protobuf.binary.tag'); /** @@ -139,6 +140,28 @@ function readRepeatedNonPrimitive(indexArray, bufferDecoder, singularReadFunc) { return result; } +/** + * Converts all entries of the index array to the template type using the given + * read function and return an Array containing those converted values. This is + * used to implement parsing repeated non-primitive fields. + * @param {!Array} indexArray + * @param {!BufferDecoder} bufferDecoder + * @param {number} fieldNumber + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {!Array} + * @template T + */ +function readRepeatedGroup( + indexArray, bufferDecoder, fieldNumber, instanceCreator, pivot) { + const result = new Array(indexArray.length); + for (let i = 0; i < indexArray.length; i++) { + result[i] = doReadGroup( + bufferDecoder, indexArray[i], fieldNumber, instanceCreator, pivot); + } + return result; +} + /** * Creates a new bytes array to contain all data of a submessage. * When there are multiple entries, merge them together. @@ -193,6 +216,51 @@ function readMessage(indexArray, bufferDecoder, instanceCreator, pivot) { return instanceCreator(accessor); } +/** + * Merges all index entries of the index array using the given read function. + * This is used to implement parsing singular group fields. + * @param {!Array} indexArray + * @param {!BufferDecoder} bufferDecoder + * @param {number} fieldNumber + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {T} + * @template T + */ +function readGroup( + indexArray, bufferDecoder, fieldNumber, instanceCreator, pivot) { + checkInstanceCreator(instanceCreator); + checkState(indexArray.length > 0); + return doReadGroup( + bufferDecoder, indexArray[indexArray.length - 1], fieldNumber, + instanceCreator, pivot); +} + +/** + * Merges all index entries of the index array using the given read function. + * This is used to implement parsing singular message fields. + * @param {!BufferDecoder} bufferDecoder + * @param {!IndexEntry} indexEntry + * @param {number} fieldNumber + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {T} + * @template T + */ +function doReadGroup( + bufferDecoder, indexEntry, fieldNumber, instanceCreator, pivot) { + validateWireType(indexEntry, WireType.START_GROUP); + const fieldStartIndex = Field.getStartIndex(indexEntry); + const tag = createTag(WireType.START_GROUP, fieldNumber); + const groupTagLength = get32BitVarintLength(tag); + const groupLength = getTagLength( + bufferDecoder, fieldStartIndex, WireType.START_GROUP, fieldNumber); + const accessorBuffer = bufferDecoder.subBufferDecoder( + fieldStartIndex, groupLength - groupTagLength); + const kernel = Kernel.fromBufferDecoder_(accessorBuffer, pivot); + return instanceCreator(kernel); +} + /** * @param {!Writer} writer * @param {number} fieldNumber @@ -203,6 +271,18 @@ function writeMessage(writer, fieldNumber, value) { fieldNumber, checkDefAndNotNull(value).internalGetKernel().serialize()); } +/** + * @param {!Writer} writer + * @param {number} fieldNumber + * @param {?InternalMessage} value + */ +function writeGroup(writer, fieldNumber, value) { + const kernel = checkDefAndNotNull(value).internalGetKernel(); + writer.writeStartGroup(fieldNumber); + kernel.serializeToWriter(writer); + writer.writeEndGroup(fieldNumber); +} + /** * Writes the array of Messages into the writer for the given field number. * @param {!Writer} writer @@ -215,6 +295,18 @@ function writeRepeatedMessage(writer, fieldNumber, values) { } } +/** + * Writes the array of Messages into the writer for the given field number. + * @param {!Writer} writer + * @param {number} fieldNumber + * @param {!Array} values + */ +function writeRepeatedGroup(writer, fieldNumber, values) { + for (const value of values) { + writeGroup(writer, fieldNumber, value); + } +} + /** * Array.from has a weird type definition in google3/javascript/externs/es6.js * and wants the mapping function accept strings. @@ -406,7 +498,8 @@ class Kernel { writer.writeTag(fieldNumber, Field.getWireType(indexEntry)); writer.writeBufferDecoder( checkDefAndNotNull(this.bufferDecoder_), - Field.getStartIndex(indexEntry), Field.getWireType(indexEntry)); + Field.getStartIndex(indexEntry), Field.getWireType(indexEntry), + fieldNumber); } } }); @@ -690,6 +783,28 @@ class Kernel { writeMessage); } + /** + * Returns data as a mutable proto Message for the given field number. + * If no value has been set, return null. + * If hasFieldNumber(fieldNumber) == false before calling, it remains false. + * + * This method should not be used along with getMessage, since calling + * getMessageOrNull after getMessage will not register the encoder. + * + * @param {number} fieldNumber + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {?T} + * @template T + */ + getGroupOrNull(fieldNumber, instanceCreator, pivot = undefined) { + return this.getFieldWithDefault_( + fieldNumber, null, + (indexArray, bytes) => + readGroup(indexArray, bytes, fieldNumber, instanceCreator, pivot), + writeGroup); + } + /** * Returns data as a mutable proto Message for the given field number. * If no value has been set previously, creates and attaches an instance. @@ -714,6 +829,30 @@ class Kernel { return instance; } + /** + * Returns data as a mutable proto Message for the given field number. + * If no value has been set previously, creates and attaches an instance. + * Postcondition: hasFieldNumber(fieldNumber) == true. + * + * This method should not be used along with getMessage, since calling + * getMessageAttach after getMessage will not register the encoder. + * + * @param {number} fieldNumber + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {T} + * @template T + */ + getGroupAttach(fieldNumber, instanceCreator, pivot = undefined) { + checkInstanceCreator(instanceCreator); + let instance = this.getGroupOrNull(fieldNumber, instanceCreator, pivot); + if (!instance) { + instance = instanceCreator(Kernel.createEmpty()); + this.setField_(fieldNumber, instance, writeGroup); + } + return instance; + } + /** * Returns data as a proto Message for the given field number. * If no value has been set, return a default instance. @@ -744,6 +883,36 @@ class Kernel { return message === null ? instanceCreator(Kernel.createEmpty()) : message; } + /** + * Returns data as a proto Message for the given field number. + * If no value has been set, return a default instance. + * This default instance is guaranteed to be the same instance, unless this + * field is cleared. + * Does not register the encoder, so changes made to the returned + * sub-message will not be included when serializing the parent message. + * Use getMessageAttach() if the resulting sub-message should be mutable. + * + * This method should not be used along with getMessageOrNull or + * getMessageAttach, since these methods register the encoder. + * + * @param {number} fieldNumber + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {T} + * @template T + */ + getGroup(fieldNumber, instanceCreator, pivot = undefined) { + checkInstanceCreator(instanceCreator); + const message = this.getFieldWithDefault_( + fieldNumber, null, + (indexArray, bytes) => + readGroup(indexArray, bytes, fieldNumber, instanceCreator, pivot)); + // Returns an empty message as the default value if the field doesn't exist. + // We don't pass the default value to getFieldWithDefault_ to reduce object + // allocation. + return message === null ? instanceCreator(Kernel.createEmpty()) : message; + } + /** * Returns the accessor for the given singular message, or returns null if * it hasn't been set. @@ -1614,6 +1783,71 @@ class Kernel { .length; } + /** + * Returns an Array instance containing boolean values for the given field + * number. + * @param {number} fieldNumber + * @param {function(!Kernel):T} instanceCreator + * @param {number|undefined} pivot + * @return {!Array} + * @template T + * @private + */ + getRepeatedGroupArray_(fieldNumber, instanceCreator, pivot) { + return this.getFieldWithDefault_( + fieldNumber, [], + (indexArray, bufferDecoder) => readRepeatedGroup( + indexArray, bufferDecoder, fieldNumber, instanceCreator, pivot), + writeRepeatedGroup); + } + + /** + * Returns the element at index for the given field number as a group. + * @param {number} fieldNumber + * @param {function(!Kernel):T} instanceCreator + * @param {number} index + * @param {number=} pivot + * @return {T} + * @template T + */ + getRepeatedGroupElement( + fieldNumber, instanceCreator, index, pivot = undefined) { + const array = + this.getRepeatedGroupArray_(fieldNumber, instanceCreator, pivot); + checkCriticalElementIndex(index, array.length); + return array[index]; + } + + /** + * Returns an Iterable instance containing group values for the given field + * number. + * @param {number} fieldNumber + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {!Iterable} + * @template T + */ + getRepeatedGroupIterable(fieldNumber, instanceCreator, pivot = undefined) { + // Don't split this statement unless needed. JS compiler thinks + // getRepeatedMessageArray_ might have side effects and doesn't inline the + // call in the compiled code. See cl/293894484 for details. + return new ArrayIterable( + this.getRepeatedGroupArray_(fieldNumber, instanceCreator, pivot)); + } + + /** + * Returns the size of the repeated field. + * @param {number} fieldNumber + * @param {function(!Kernel):T} instanceCreator + * @return {number} + * @param {number=} pivot + * @template T + */ + getRepeatedGroupSize(fieldNumber, instanceCreator, pivot = undefined) { + return this.getRepeatedGroupArray_(fieldNumber, instanceCreator, pivot) + .length; + } + /*************************************************************************** * OPTIONAL SETTER METHODS ***************************************************************************/ @@ -1801,6 +2035,20 @@ class Kernel { this.setInt64(fieldNumber, value); } + /** + * Sets a proto Group to the field with the given field number. + * Instead of working with the Kernel inside of the message directly, we + * need the message instance to keep its reference equality for subsequent + * gettings. + * @param {number} fieldNumber + * @param {!InternalMessage} value + */ + setGroup(fieldNumber, value) { + checkCriticalType( + value !== null, 'Given value is not a message instance: null'); + this.setField_(fieldNumber, value, writeGroup); + } + /** * Sets a proto Message to the field with the given field number. * Instead of working with the Kernel inside of the message directly, we @@ -3806,6 +4054,69 @@ class Kernel { this.addRepeatedMessageIterable( fieldNumber, [value], instanceCreator, pivot); } + + // Groups + /** + * Sets all message values into the field for the given field number. + * @param {number} fieldNumber + * @param {!Iterable} values + */ + setRepeatedGroupIterable(fieldNumber, values) { + const /** !Array */ array = Array.from(values); + checkCriticalTypeMessageArray(array); + this.setField_(fieldNumber, array, writeRepeatedGroup); + } + + /** + * Adds all message values into the field for the given field number. + * @param {number} fieldNumber + * @param {!Iterable} values + * @param {function(!Kernel):!InternalMessage} instanceCreator + * @param {number=} pivot + */ + addRepeatedGroupIterable( + fieldNumber, values, instanceCreator, pivot = undefined) { + const array = [ + ...this.getRepeatedGroupArray_(fieldNumber, instanceCreator, pivot), + ...values, + ]; + checkCriticalTypeMessageArray(array); + // Needs to set it back with the new array. + this.setField_(fieldNumber, array, writeRepeatedGroup); + } + + /** + * Sets a single message value into the field for the given field number at + * the given index. + * @param {number} fieldNumber + * @param {!InternalMessage} value + * @param {function(!Kernel):!InternalMessage} instanceCreator + * @param {number} index + * @param {number=} pivot + * @throws {!Error} if index is out of range when check mode is critical + */ + setRepeatedGroupElement( + fieldNumber, value, instanceCreator, index, pivot = undefined) { + checkInstanceCreator(instanceCreator); + checkCriticalType( + value !== null, 'Given value is not a message instance: null'); + const array = + this.getRepeatedGroupArray_(fieldNumber, instanceCreator, pivot); + checkCriticalElementIndex(index, array.length); + array[index] = value; + } + + /** + * Adds a single message value into the field for the given field number. + * @param {number} fieldNumber + * @param {!InternalMessage} value + * @param {function(!Kernel):!InternalMessage} instanceCreator + * @param {number=} pivot + */ + addRepeatedGroupElement( + fieldNumber, value, instanceCreator, pivot = undefined) { + this.addRepeatedGroupIterable(fieldNumber, [value], instanceCreator, pivot); + } } exports = Kernel; diff --git a/js/experimental/runtime/kernel/kernel_repeated_test.js b/js/experimental/runtime/kernel/kernel_repeated_test.js index 7741c35c38..6a798b6c09 100644 --- a/js/experimental/runtime/kernel/kernel_repeated_test.js +++ b/js/experimental/runtime/kernel/kernel_repeated_test.js @@ -7425,3 +7425,383 @@ describe('Kernel for repeated message does', () => { } }); }); + +describe('Kernel for repeated groups does', () => { + it('return empty array for the empty input', () => { + const accessor = Kernel.createEmpty(); + expectEqualToArray( + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator), []); + }); + + it('ensure not the same instance returned for the empty input', () => { + const accessor = Kernel.createEmpty(); + const list1 = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + const list2 = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + expect(list1).not.toBe(list2); + }); + + it('return size for the empty input', () => { + const accessor = Kernel.createEmpty(); + expect(accessor.getRepeatedGroupSize(1, TestMessage.instanceCreator)) + .toEqual(0); + }); + + it('return values from the input', () => { + const bytes1 = createArrayBuffer(0x08, 0x01); + const bytes2 = createArrayBuffer(0x08, 0x02); + const msg1 = new TestMessage(Kernel.fromArrayBuffer(bytes1)); + const msg2 = new TestMessage(Kernel.fromArrayBuffer(bytes2)); + + const bytes = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + expectEqualToMessageArray( + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator), + [msg1, msg2]); + }); + + it('ensure not the same array instance returned', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const list1 = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + const list2 = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + expect(list1).not.toBe(list2); + }); + + it('ensure the same array element returned for get iterable', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const list1 = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + const list2 = accessor.getRepeatedGroupIterable( + 1, TestMessage.instanceCreator, /* pivot= */ 0); + const array1 = Array.from(list1); + const array2 = Array.from(list2); + for (let i = 0; i < array1.length; i++) { + expect(array1[i]).toBe(array2[i]); + } + }); + + it('return accessors from the input', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const [accessor1, accessor2] = + [...accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator)]; + expect(accessor1.getInt32WithDefault(1)).toEqual(1); + expect(accessor2.getInt32WithDefault(1)).toEqual(2); + }); + + it('return accessors from the input when pivot is set', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const [accessor1, accessor2] = [...accessor.getRepeatedGroupIterable( + 1, TestMessage.instanceCreator, /* pivot= */ 0)]; + expect(accessor1.getInt32WithDefault(1)).toEqual(1); + expect(accessor2.getInt32WithDefault(1)).toEqual(2); + }); + + it('return the repeated field element from the input', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const msg1 = accessor.getRepeatedGroupElement( + /* fieldNumber= */ 1, TestMessage.instanceCreator, + /* index= */ 0); + const msg2 = accessor.getRepeatedGroupElement( + /* fieldNumber= */ 1, TestMessage.instanceCreator, + /* index= */ 1, /* pivot= */ 0); + expect(msg1.getInt32WithDefault( + /* fieldNumber= */ 1, /* default= */ 0)) + .toEqual(1); + expect(msg2.getInt32WithDefault( + /* fieldNumber= */ 1, /* default= */ 0)) + .toEqual(2); + }); + + it('ensure the same array element returned', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const msg1 = accessor.getRepeatedGroupElement( + /* fieldNumber= */ 1, TestMessage.instanceCreator, + /* index= */ 0); + const msg2 = accessor.getRepeatedGroupElement( + /* fieldNumber= */ 1, TestMessage.instanceCreator, + /* index= */ 0); + expect(msg1).toBe(msg2); + }); + + it('return the size from the input', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + expect(accessor.getRepeatedGroupSize(1, TestMessage.instanceCreator)) + .toEqual(2); + }); + + it('encode repeated message from the input', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + expect(accessor.serialize()).toEqual(bytes); + }); + + it('add a single value', () => { + const accessor = Kernel.createEmpty(); + const bytes1 = createArrayBuffer(0x08, 0x01); + const msg1 = new TestMessage(Kernel.fromArrayBuffer(bytes1)); + const bytes2 = createArrayBuffer(0x08, 0x02); + const msg2 = new TestMessage(Kernel.fromArrayBuffer(bytes2)); + + accessor.addRepeatedGroupElement(1, msg1, TestMessage.instanceCreator); + accessor.addRepeatedGroupElement(1, msg2, TestMessage.instanceCreator); + const result = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + + expect(Array.from(result)).toEqual([msg1, msg2]); + }); + + it('add values', () => { + const accessor = Kernel.createEmpty(); + const bytes1 = createArrayBuffer(0x08, 0x01); + const msg1 = new TestMessage(Kernel.fromArrayBuffer(bytes1)); + const bytes2 = createArrayBuffer(0x08, 0x02); + const msg2 = new TestMessage(Kernel.fromArrayBuffer(bytes2)); + + accessor.addRepeatedGroupIterable(1, [msg1], TestMessage.instanceCreator); + accessor.addRepeatedGroupIterable(1, [msg2], TestMessage.instanceCreator); + const result = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + + expect(Array.from(result)).toEqual([msg1, msg2]); + }); + + it('set a single value', () => { + const bytes = createArrayBuffer(0x0B, 0x08, 0x01, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const subbytes = createArrayBuffer(0x08, 0x01); + const submsg = new TestMessage(Kernel.fromArrayBuffer(subbytes)); + + accessor.setRepeatedGroupElement( + /* fieldNumber= */ 1, submsg, TestMessage.instanceCreator, + /* index= */ 0); + const result = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + + expect(Array.from(result)).toEqual([submsg]); + }); + + it('write submessage changes made via getRepeatedGroupElement', () => { + const bytes = createArrayBuffer(0x0B, 0x08, 0x05, 0x0C); + const expected = createArrayBuffer(0x0B, 0x08, 0x00, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const submsg = accessor.getRepeatedGroupElement( + /* fieldNumber= */ 1, TestMessage.instanceCreator, + /* index= */ 0); + expect(submsg.getInt32WithDefault(1, 0)).toEqual(5); + submsg.setInt32(1, 0); + + expect(accessor.serialize()).toEqual(expected); + }); + + it('set values', () => { + const accessor = Kernel.createEmpty(); + const subbytes = createArrayBuffer(0x08, 0x01); + const submsg = new TestMessage(Kernel.fromArrayBuffer(subbytes)); + + accessor.setRepeatedGroupIterable(1, [submsg]); + const result = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + + expect(Array.from(result)).toEqual([submsg]); + }); + + it('encode for adding single value', () => { + const accessor = Kernel.createEmpty(); + const bytes1 = createArrayBuffer(0x08, 0x01); + const msg1 = new TestMessage(Kernel.fromArrayBuffer(bytes1)); + const bytes2 = createArrayBuffer(0x08, 0x00); + const msg2 = new TestMessage(Kernel.fromArrayBuffer(bytes2)); + const expected = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x00, 0x0C); + + accessor.addRepeatedGroupElement(1, msg1, TestMessage.instanceCreator); + accessor.addRepeatedGroupElement(1, msg2, TestMessage.instanceCreator); + const result = accessor.serialize(); + + expect(result).toEqual(expected); + }); + + it('encode for adding values', () => { + const accessor = Kernel.createEmpty(); + const bytes1 = createArrayBuffer(0x08, 0x01); + const msg1 = new TestMessage(Kernel.fromArrayBuffer(bytes1)); + const bytes2 = createArrayBuffer(0x08, 0x00); + const msg2 = new TestMessage(Kernel.fromArrayBuffer(bytes2)); + const expected = + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C, 0x0B, 0x08, 0x00, 0x0C); + + accessor.addRepeatedGroupIterable( + 1, [msg1, msg2], TestMessage.instanceCreator); + const result = accessor.serialize(); + + expect(result).toEqual(expected); + }); + + it('encode for setting single value', () => { + const bytes = createArrayBuffer(0x0B, 0x08, 0x00, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const subbytes = createArrayBuffer(0x08, 0x01); + const submsg = new TestMessage(Kernel.fromArrayBuffer(subbytes)); + const expected = createArrayBuffer(0x0B, 0x08, 0x01, 0x0C); + + accessor.setRepeatedGroupElement( + /* fieldNumber= */ 1, submsg, TestMessage.instanceCreator, + /* index= */ 0); + const result = accessor.serialize(); + + expect(result).toEqual(expected); + }); + + it('encode for setting values', () => { + const accessor = Kernel.createEmpty(); + const subbytes = createArrayBuffer(0x08, 0x01); + const submsg = new TestMessage(Kernel.fromArrayBuffer(subbytes)); + const expected = createArrayBuffer(0x0B, 0x08, 0x01, 0x0C); + + accessor.setRepeatedGroupIterable(1, [submsg]); + const result = accessor.serialize(); + + expect(result).toEqual(expected); + }); + + it('fail when getting groups value with other wire types', () => { + const accessor = Kernel.fromArrayBuffer(createArrayBuffer( + 0x09, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)); + + if (CHECK_CRITICAL_STATE) { + expect(() => { + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + }).toThrow(); + } + }); + + it('fail when adding group values with wrong type value', () => { + const accessor = Kernel.createEmpty(); + const fakeValue = /** @type {!TestMessage} */ (/** @type {*} */ (null)); + if (CHECK_CRITICAL_STATE) { + expect( + () => accessor.addRepeatedGroupIterable( + 1, [fakeValue], TestMessage.instanceCreator)) + .toThrowError('Given value is not a message instance: null'); + } else { + // Note in unchecked mode we produce invalid output for invalid inputs. + // This test just documents our behavior in those cases. + // These values might change at any point and are not considered + // what the implementation should be doing here. + accessor.addRepeatedGroupIterable( + 1, [fakeValue], TestMessage.instanceCreator); + const list = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + expect(Array.from(list)).toEqual([null]); + } + }); + + it('fail when adding single group value with wrong type value', () => { + const accessor = Kernel.createEmpty(); + const fakeValue = /** @type {!TestMessage} */ (/** @type {*} */ (null)); + if (CHECK_CRITICAL_STATE) { + expect( + () => accessor.addRepeatedGroupElement( + 1, fakeValue, TestMessage.instanceCreator)) + .toThrowError('Given value is not a message instance: null'); + } else { + // Note in unchecked mode we produce invalid output for invalid inputs. + // This test just documents our behavior in those cases. + // These values might change at any point and are not considered + // what the implementation should be doing here. + accessor.addRepeatedGroupElement( + 1, fakeValue, TestMessage.instanceCreator); + const list = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + expect(Array.from(list)).toEqual([null]); + } + }); + + it('fail when setting message values with wrong type value', () => { + const accessor = Kernel.createEmpty(); + const fakeValue = /** @type {!TestMessage} */ (/** @type {*} */ (null)); + if (CHECK_CRITICAL_STATE) { + expect(() => accessor.setRepeatedGroupIterable(1, [fakeValue])) + .toThrowError('Given value is not a message instance: null'); + } else { + // Note in unchecked mode we produce invalid output for invalid inputs. + // This test just documents our behavior in those cases. + // These values might change at any point and are not considered + // what the implementation should be doing here. + accessor.setRepeatedGroupIterable(1, [fakeValue]); + const list = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + expect(Array.from(list)).toEqual([null]); + } + }); + + it('fail when setting single value with wrong type value', () => { + const accessor = + Kernel.fromArrayBuffer(createArrayBuffer(0x0B, 0x08, 0x00, 0x0C)); + const fakeValue = /** @type {!TestMessage} */ (/** @type {*} */ (null)); + if (CHECK_CRITICAL_STATE) { + expect( + () => accessor.setRepeatedGroupElement( + /* fieldNumber= */ 1, fakeValue, TestMessage.instanceCreator, + /* index= */ 0)) + .toThrowError('Given value is not a message instance: null'); + } else { + // Note in unchecked mode we produce invalid output for invalid inputs. + // This test just documents our behavior in those cases. + // These values might change at any point and are not considered + // what the implementation should be doing here. + accessor.setRepeatedGroupElement( + /* fieldNumber= */ 1, fakeValue, TestMessage.instanceCreator, + /* index= */ 0); + const list = + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator); + expect(Array.from(list).length).toEqual(1); + } + }); + + it('fail when setting single value with out-of-bound index', () => { + const accessor = + Kernel.fromArrayBuffer(createArrayBuffer(0x0B, 0x08, 0x00, 0x0C)); + const msg1 = + accessor.getRepeatedGroupElement(1, TestMessage.instanceCreator, 0); + const bytes2 = createArrayBuffer(0x08, 0x01); + const msg2 = new TestMessage(Kernel.fromArrayBuffer(bytes2)); + if (CHECK_CRITICAL_STATE) { + expect( + () => accessor.setRepeatedGroupElement( + /* fieldNumber= */ 1, msg2, TestMessage.instanceCreator, + /* index= */ 1)) + .toThrowError('Index out of bounds: index: 1 size: 1'); + } else { + // Note in unchecked mode we produce invalid output for invalid inputs. + // This test just documents our behavior in those cases. + // These values might change at any point and are not considered + // what the implementation should be doing here. + accessor.setRepeatedGroupElement( + /* fieldNumber= */ 1, msg2, TestMessage.instanceCreator, + /* index= */ 1); + expectEqualToArray( + accessor.getRepeatedGroupIterable(1, TestMessage.instanceCreator), + [msg1, msg2]); + } + }); +}); diff --git a/js/experimental/runtime/kernel/kernel_test.js b/js/experimental/runtime/kernel/kernel_test.js index 61cdeec657..e72be4f3b6 100644 --- a/js/experimental/runtime/kernel/kernel_test.js +++ b/js/experimental/runtime/kernel/kernel_test.js @@ -2075,3 +2075,255 @@ describe('Double access', () => { } }); }); + +describe('Kernel for singular group does', () => { + it('return group from the input', () => { + const bytes = createArrayBuffer(0x0B, 0x08, 0x01, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const msg = accessor.getGroupOrNull(1, TestMessage.instanceCreator); + expect(msg.getBoolWithDefault(1, false)).toBe(true); + }); + + it('return group from the input when pivot is set', () => { + const bytes = createArrayBuffer(0x0B, 0x08, 0x01, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const msg = accessor.getGroupOrNull(1, TestMessage.instanceCreator, 0); + expect(msg.getBoolWithDefault(1, false)).toBe(true); + }); + + it('encode group from the input', () => { + const bytes = createArrayBuffer(0x0B, 0x08, 0x01, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + expect(accessor.serialize()).toEqual(bytes); + }); + + it('encode group from the input after read', () => { + const bytes = createArrayBuffer(0x0B, 0x08, 0x01, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + accessor.getGroupOrNull(1, TestMessage.instanceCreator); + expect(accessor.serialize()).toEqual(bytes); + }); + + it('return last group from multiple inputs', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x00, 0x0C, 0x0B, 0x08, 0x01, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const msg = accessor.getGroupOrNull(1, TestMessage.instanceCreator); + expect(msg.getBoolWithDefault(1, false)).toBe(true); + }); + + it('removes duplicated group when serializing', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x00, 0x0C, 0x0B, 0x08, 0x01, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + accessor.getGroupOrNull(1, TestMessage.instanceCreator); + expect(accessor.serialize()) + .toEqual(createArrayBuffer(0x0B, 0x08, 0x01, 0x0C)); + }); + + it('encode group from multiple inputs', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x00, 0x0C, 0x0B, 0x08, 0x01, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + expect(accessor.serialize()).toEqual(bytes); + }); + + it('encode group after read', () => { + const bytes = + createArrayBuffer(0x0B, 0x08, 0x00, 0x0C, 0x0B, 0x08, 0x01, 0x0C); + const expected = createArrayBuffer(0x0B, 0x08, 0x01, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + accessor.getGroupOrNull(1, TestMessage.instanceCreator); + expect(accessor.serialize()).toEqual(expected); + }); + + it('return group from setter', () => { + const bytes = createArrayBuffer(0x08, 0x01); + const accessor = Kernel.fromArrayBuffer(new ArrayBuffer(0)); + const subaccessor = Kernel.fromArrayBuffer(bytes); + const submsg1 = new TestMessage(subaccessor); + accessor.setGroup(1, submsg1); + const submsg2 = accessor.getGroup(1, TestMessage.instanceCreator); + expect(submsg1).toBe(submsg2); + }); + + it('encode group from setter', () => { + const accessor = Kernel.fromArrayBuffer(new ArrayBuffer(0)); + const subaccessor = Kernel.fromArrayBuffer(createArrayBuffer(0x08, 0x01)); + const submsg = new TestMessage(subaccessor); + accessor.setGroup(1, submsg); + const expected = createArrayBuffer(0x0B, 0x08, 0x01, 0x0C); + expect(accessor.serialize()).toEqual(expected); + }); + + it('leave hasFieldNumber unchanged after getGroupOrNull', () => { + const accessor = Kernel.createEmpty(); + expect(accessor.hasFieldNumber(1)).toBe(false); + expect(accessor.getGroupOrNull(1, TestMessage.instanceCreator)).toBe(null); + expect(accessor.hasFieldNumber(1)).toBe(false); + }); + + it('serialize changes to subgroups made with getGroupsOrNull', () => { + const intTwoBytes = createArrayBuffer(0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(intTwoBytes); + const mutableSubMessage = + accessor.getGroupOrNull(1, TestMessage.instanceCreator); + mutableSubMessage.setInt32(1, 10); + const intTenBytes = createArrayBuffer(0x0B, 0x08, 0x0A, 0x0C); + expect(accessor.serialize()).toEqual(intTenBytes); + }); + + it('serialize additions to subgroups made with getGroupOrNull', () => { + const intTwoBytes = createArrayBuffer(0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(intTwoBytes); + const mutableSubMessage = + accessor.getGroupOrNull(1, TestMessage.instanceCreator); + mutableSubMessage.setInt32(2, 3); + // Sub group contains the original field, plus the new one. + expect(accessor.serialize()) + .toEqual(createArrayBuffer(0x0B, 0x08, 0x02, 0x10, 0x03, 0x0C)); + }); + + it('fail with getGroupOrNull if immutable group exist in cache', () => { + const intTwoBytes = createArrayBuffer(0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(intTwoBytes); + + const readOnly = accessor.getGroup(1, TestMessage.instanceCreator); + if (CHECK_TYPE) { + expect(() => accessor.getGroupOrNull(1, TestMessage.instanceCreator)) + .toThrow(); + } else { + const mutableSubGropu = + accessor.getGroupOrNull(1, TestMessage.instanceCreator); + // The instance returned by getGroupOrNull is the exact same instance. + expect(mutableSubGropu).toBe(readOnly); + + // Serializing the subgroup does not write the changes + mutableSubGropu.setInt32(1, 0); + expect(accessor.serialize()).toEqual(intTwoBytes); + } + }); + + it('change hasFieldNumber after getGroupAttach', () => { + const accessor = Kernel.createEmpty(); + expect(accessor.hasFieldNumber(1)).toBe(false); + expect(accessor.getGroupAttach(1, TestMessage.instanceCreator)) + .not.toBe(null); + expect(accessor.hasFieldNumber(1)).toBe(true); + }); + + it('change hasFieldNumber after getGroupAttach when pivot is set', () => { + const accessor = Kernel.createEmpty(); + expect(accessor.hasFieldNumber(1)).toBe(false); + expect( + accessor.getGroupAttach(1, TestMessage.instanceCreator, /* pivot= */ 1)) + .not.toBe(null); + expect(accessor.hasFieldNumber(1)).toBe(true); + }); + + it('serialize subgroups made with getGroupAttach', () => { + const accessor = Kernel.createEmpty(); + const mutableSubGroup = + accessor.getGroupAttach(1, TestMessage.instanceCreator); + mutableSubGroup.setInt32(1, 10); + const intTenBytes = createArrayBuffer(0x0B, 0x08, 0x0A, 0x0C); + expect(accessor.serialize()).toEqual(intTenBytes); + }); + + it('serialize additions to subgroups using getMessageAttach', () => { + const intTwoBytes = createArrayBuffer(0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(intTwoBytes); + const mutableSubGroup = + accessor.getGroupAttach(1, TestMessage.instanceCreator); + mutableSubGroup.setInt32(2, 3); + // Sub message contains the original field, plus the new one. + expect(accessor.serialize()) + .toEqual(createArrayBuffer(0x0B, 0x08, 0x02, 0x10, 0x03, 0x0C)); + }); + + it('fail with getGroupAttach if immutable message exist in cache', () => { + const intTwoBytes = createArrayBuffer(0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(intTwoBytes); + + const readOnly = accessor.getGroup(1, TestMessage.instanceCreator); + if (CHECK_TYPE) { + expect(() => accessor.getGroupAttach(1, TestMessage.instanceCreator)) + .toThrow(); + } else { + const mutableSubGroup = + accessor.getGroupAttach(1, TestMessage.instanceCreator); + // The instance returned by getMessageOrNull is the exact same instance. + expect(mutableSubGroup).toBe(readOnly); + + // Serializing the submessage does not write the changes + mutableSubGroup.setInt32(1, 0); + expect(accessor.serialize()).toEqual(intTwoBytes); + } + }); + + it('read default group return empty group with getGroup', () => { + const bytes = new ArrayBuffer(0); + const accessor = Kernel.fromArrayBuffer(bytes); + expect(accessor.getGroup(1, TestMessage.instanceCreator)).toBeTruthy(); + expect(accessor.getGroup(1, TestMessage.instanceCreator).serialize()) + .toEqual(bytes); + }); + + it('read default group return null with getGroupOrNull', () => { + const bytes = new ArrayBuffer(0); + const accessor = Kernel.fromArrayBuffer(bytes); + expect(accessor.getGroupOrNull(1, TestMessage.instanceCreator)).toBe(null); + }); + + it('read group preserve reference equality', () => { + const bytes = createArrayBuffer(0x0B, 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const msg1 = accessor.getGroupOrNull(1, TestMessage.instanceCreator); + const msg2 = accessor.getGroupOrNull(1, TestMessage.instanceCreator); + const msg3 = accessor.getGroupAttach(1, TestMessage.instanceCreator); + expect(msg1).toBe(msg2); + expect(msg1).toBe(msg3); + }); + + it('fail when getting group with null instance constructor', () => { + const accessor = + Kernel.fromArrayBuffer(createArrayBuffer(0x0A, 0x02, 0x08, 0x01)); + const nullMessage = /** @type {function(!Kernel):!TestMessage} */ + (/** @type {*} */ (null)); + expect(() => accessor.getGroupOrNull(1, nullMessage)).toThrow(); + }); + + it('fail when setting group value with null value', () => { + const accessor = Kernel.fromArrayBuffer(new ArrayBuffer(0)); + const fakeMessage = /** @type {!TestMessage} */ (/** @type {*} */ (null)); + if (CHECK_CRITICAL_TYPE) { + expect(() => accessor.setGroup(1, fakeMessage)) + .toThrowError('Given value is not a message instance: null'); + } else { + // Note in unchecked mode we produce invalid output for invalid inputs. + // This test just documents our behavior in those cases. + // These values might change at any point and are not considered + // what the implementation should be doing here. + accessor.setMessage(1, fakeMessage); + expect(accessor.getGroupOrNull( + /* fieldNumber= */ 1, TestMessage.instanceCreator)) + .toBeNull(); + } + }); + + it('reads group in a longer buffer', () => { + const bytes = createArrayBuffer( + 0x12, 0x20, // 32 length delimited + 0x00, // random values for padding start + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, // random values for padding end + 0x0B, // Group tag + 0x08, 0x02, 0x0C); + const accessor = Kernel.fromArrayBuffer(bytes); + const msg1 = accessor.getGroupOrNull(1, TestMessage.instanceCreator); + const msg2 = accessor.getGroupOrNull(1, TestMessage.instanceCreator); + expect(msg1).toBe(msg2); + }); +}); diff --git a/js/experimental/runtime/kernel/message_set.js b/js/experimental/runtime/kernel/message_set.js new file mode 100644 index 0000000000..d66bace7b5 --- /dev/null +++ b/js/experimental/runtime/kernel/message_set.js @@ -0,0 +1,285 @@ +/* +########################################################## +# # +# __ __ _____ _ _ _____ _ _ _____ # +# \ \ / /\ | __ \| \ | |_ _| \ | |/ ____| # +# \ \ /\ / / \ | |__) | \| | | | | \| | | __ # +# \ \/ \/ / /\ \ | _ /| . ` | | | | . ` | | |_ | # +# \ /\ / ____ \| | \ \| |\ |_| |_| |\ | |__| | # +# \/ \/_/ \_\_| \_\_| \_|_____|_| \_|\_____| # +# # +# # +########################################################## +# Do not use this class in your code. This class purely # +# exists to make proto code generation easier. # +########################################################## +*/ +goog.module('protobuf.runtime.MessageSet'); + +const InternalMessage = goog.require('protobuf.binary.InternalMessage'); +const Kernel = goog.require('protobuf.runtime.Kernel'); + +// These are the tags for the old MessageSet format, which was defined as: +// message MessageSet { +// repeated group Item = 1 { +// required uint32 type_id = 2; +// optional bytes message = 3; +// } +// } +/** @const {number} */ +const MSET_GROUP_FIELD_NUMBER = 1; +/** @const {number} */ +const MSET_TYPE_ID_FIELD_NUMBER = 2; +/** @const {number} */ +const MSET_MESSAGE_FIELD_NUMBER = 3; + +/** + * @param {!Kernel} kernel + * @return {!Map} + */ +function createItemMap(kernel) { + const itemMap = new Map(); + let totalCount = 0; + for (const item of kernel.getRepeatedGroupIterable( + MSET_GROUP_FIELD_NUMBER, Item.fromKernel)) { + itemMap.set(item.getTypeId(), item); + totalCount++; + } + + // Normalize the entries. + if (totalCount > itemMap.size) { + writeItemMap(kernel, itemMap); + } + return itemMap; +} + +/** + * @param {!Kernel} kernel + * @param {!Map} itemMap + */ +function writeItemMap(kernel, itemMap) { + kernel.setRepeatedGroupIterable(MSET_GROUP_FIELD_NUMBER, itemMap.values()); +} + +/** + * @implements {InternalMessage} + * @final + */ +class MessageSet { + /** + * @param {!Kernel} kernel + * @return {!MessageSet} + */ + static fromKernel(kernel) { + const itemMap = createItemMap(kernel); + return new MessageSet(kernel, itemMap); + } + + /** + * @return {!MessageSet} + */ + static createEmpty() { + return MessageSet.fromKernel(Kernel.createEmpty()); + } + + /** + * @param {!Kernel} kernel + * @param {!Map} itemMap + * @private + */ + constructor(kernel, itemMap) { + /** @const {!Kernel} @private */ + this.kernel_ = kernel; + /** @const {!Map} @private */ + this.itemMap_ = itemMap; + } + + + + // code helpers for code gen + + /** + * @param {number} typeId + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {?T} + * @template T + */ + getMessageOrNull(typeId, instanceCreator, pivot) { + const item = this.itemMap_.get(typeId); + return item ? item.getMessageOrNull(instanceCreator, pivot) : null; + } + + /** + * @param {number} typeId + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {T} + * @template T + */ + getMessageAttach(typeId, instanceCreator, pivot) { + let item = this.itemMap_.get(typeId); + if (item) { + return item.getMessageAttach(instanceCreator, pivot); + } + const message = instanceCreator(Kernel.createEmpty()); + this.setMessage(typeId, message); + return message; + } + + /** + * @param {number} typeId + * @param {number=} pivot + * @return {?Kernel} + */ + getMessageAccessorOrNull(typeId, pivot) { + const item = this.itemMap_.get(typeId); + return item ? item.getMessageAccessorOrNull(pivot) : null; + } + + + /** + * @param {number} typeId + */ + clearMessage(typeId) { + if (this.itemMap_.delete(typeId)) { + writeItemMap(this.kernel_, this.itemMap_); + } + } + + /** + * @param {number} typeId + * @return {boolean} + */ + hasMessage(typeId) { + return this.itemMap_.has(typeId); + } + + /** + * @param {number} typeId + * @param {!InternalMessage} value + */ + setMessage(typeId, value) { + const item = this.itemMap_.get(typeId); + if (item) { + item.setMessage(value); + } else { + this.itemMap_.set(typeId, Item.create(typeId, value)); + writeItemMap(this.kernel_, this.itemMap_); + } + } + + /** + * @return {!Kernel} + * @override + */ + internalGetKernel() { + return this.kernel_; + } +} + +/** + * @implements {InternalMessage} + * @final + */ +class Item { + /** + * @param {number} typeId + * @param {!InternalMessage} message + * @return {!Item} + */ + static create(typeId, message) { + const messageSet = Item.fromKernel(Kernel.createEmpty()); + messageSet.setTypeId_(typeId); + messageSet.setMessage(message); + return messageSet; + } + + + /** + * @param {!Kernel} kernel + * @return {!Item} + */ + static fromKernel(kernel) { + return new Item(kernel); + } + + /** + * @param {!Kernel} kernel + * @private + */ + constructor(kernel) { + /** @const {!Kernel} @private */ + this.kernel_ = kernel; + } + + /** + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {T} + * @template T + */ + getMessage(instanceCreator, pivot) { + return this.kernel_.getMessage( + MSET_MESSAGE_FIELD_NUMBER, instanceCreator, pivot); + } + + /** + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {?T} + * @template T + */ + getMessageOrNull(instanceCreator, pivot) { + return this.kernel_.getMessageOrNull( + MSET_MESSAGE_FIELD_NUMBER, instanceCreator, pivot); + } + + /** + * @param {function(!Kernel):T} instanceCreator + * @param {number=} pivot + * @return {T} + * @template T + */ + getMessageAttach(instanceCreator, pivot) { + return this.kernel_.getMessageAttach( + MSET_MESSAGE_FIELD_NUMBER, instanceCreator, pivot); + } + + /** + * @param {number=} pivot + * @return {?Kernel} + */ + getMessageAccessorOrNull(pivot) { + return this.kernel_.getMessageAccessorOrNull( + MSET_MESSAGE_FIELD_NUMBER, pivot); + } + + /** @param {!InternalMessage} value */ + setMessage(value) { + this.kernel_.setMessage(MSET_MESSAGE_FIELD_NUMBER, value); + } + + /** @return {number} */ + getTypeId() { + return this.kernel_.getUint32WithDefault(MSET_TYPE_ID_FIELD_NUMBER); + } + + /** + * @param {number} value + * @private + */ + setTypeId_(value) { + this.kernel_.setUint32(MSET_TYPE_ID_FIELD_NUMBER, value); + } + + /** + * @return {!Kernel} + * @override + */ + internalGetKernel() { + return this.kernel_; + } +} + +exports = MessageSet; diff --git a/js/experimental/runtime/kernel/message_set_test.js b/js/experimental/runtime/kernel/message_set_test.js new file mode 100644 index 0000000000..35e5935015 --- /dev/null +++ b/js/experimental/runtime/kernel/message_set_test.js @@ -0,0 +1,262 @@ +/** + * @fileoverview Tests for message_set.js. + */ +goog.module('protobuf.runtime.MessageSetTest'); + +goog.setTestOnly(); + +const Kernel = goog.require('protobuf.runtime.Kernel'); +const MessageSet = goog.require('protobuf.runtime.MessageSet'); +const TestMessage = goog.require('protobuf.testing.binary.TestMessage'); + +/** + * @param {...number} bytes + * @return {!ArrayBuffer} + */ +function createArrayBuffer(...bytes) { + return new Uint8Array(bytes).buffer; +} + +describe('MessageSet does', () => { + it('returns no messages for empty set', () => { + const messageSet = MessageSet.createEmpty(); + expect(messageSet.getMessageOrNull(12345, TestMessage.instanceCreator)) + .toBeNull(); + }); + + it('returns no kernel for empty set', () => { + const messageSet = MessageSet.createEmpty(); + expect(messageSet.getMessageAccessorOrNull(12345)).toBeNull(); + }); + + it('returns message that has been set', () => { + const messageSet = MessageSet.createEmpty(); + const message = TestMessage.createEmpty(); + messageSet.setMessage(12345, message); + expect(messageSet.getMessageOrNull(12345, TestMessage.instanceCreator)) + .toBe(message); + }); + + it('returns null for cleared message', () => { + const messageSet = MessageSet.createEmpty(); + const message = TestMessage.createEmpty(); + messageSet.setMessage(12345, message); + messageSet.clearMessage(12345); + expect(messageSet.getMessageAccessorOrNull(12345)).toBeNull(); + }); + + it('returns false for not present message', () => { + const messageSet = MessageSet.createEmpty(); + expect(messageSet.hasMessage(12345)).toBe(false); + }); + + it('returns true for present message', () => { + const messageSet = MessageSet.createEmpty(); + const message = TestMessage.createEmpty(); + messageSet.setMessage(12345, message); + expect(messageSet.hasMessage(12345)).toBe(true); + }); + + it('returns false for cleared message', () => { + const messageSet = MessageSet.createEmpty(); + const message = TestMessage.createEmpty(); + messageSet.setMessage(12345, message); + messageSet.clearMessage(12345); + expect(messageSet.hasMessage(12345)).toBe(false); + }); + + it('returns false for cleared message without it being present', () => { + const messageSet = MessageSet.createEmpty(); + messageSet.clearMessage(12345); + expect(messageSet.hasMessage(12345)).toBe(false); + }); + + const createMessageSet = () => { + const messageSet = MessageSet.createEmpty(); + const message = TestMessage.createEmpty(); + message.setInt32(1, 2); + messageSet.setMessage(12345, message); + + + const parsedKernel = + Kernel.fromArrayBuffer(messageSet.internalGetKernel().serialize()); + return MessageSet.fromKernel(parsedKernel); + }; + + it('pass through pivot for getMessageOrNull', () => { + const messageSet = createMessageSet(); + const message = + messageSet.getMessageOrNull(12345, TestMessage.instanceCreator, 2); + expect(message.internalGetKernel().getPivot()).toBe(2); + }); + + it('pass through pivot for getMessageAttach', () => { + const messageSet = createMessageSet(); + const message = + messageSet.getMessageAttach(12345, TestMessage.instanceCreator, 2); + expect(message.internalGetKernel().getPivot()).toBe(2); + }); + + it('pass through pivot for getMessageAccessorOrNull', () => { + const messageSet = createMessageSet(); + const kernel = messageSet.getMessageAccessorOrNull(12345, 2); + expect(kernel.getPivot()).toBe(2); + }); + + it('pick the last value in the stream', () => { + const arrayBuffer = createArrayBuffer( + 0x52, // Tag (field:10, length delimited) + 0x14, // Length of 20 bytes + 0x0B, // Start group fieldnumber 1 + 0x10, // Tag (field 2, varint) + 0xB9, // 12345 + 0x60, // 12345 + 0x1A, // Tag (field 3, length delimited) + 0x03, // length 3 + 0xA0, // Tag (fieldnumber 20, varint) + 0x01, // Tag (fieldnumber 20, varint) + 0x1E, // 30 + 0x0C, // Stop Group field number 1 + // second group + 0x0B, // Start group fieldnumber 1 + 0x10, // Tag (field 2, varint) + 0xB9, // 12345 + 0x60, // 12345 + 0x1A, // Tag (field 3, length delimited) + 0x03, // length 3 + 0xA0, // Tag (fieldnumber 20, varint) + 0x01, // Tag (fieldnumber 20, varint) + 0x01, // 1 + 0x0C // Stop Group field number 1 + ); + + const outerMessage = Kernel.fromArrayBuffer(arrayBuffer); + + const messageSet = outerMessage.getMessage(10, MessageSet.fromKernel); + + const message = + messageSet.getMessageOrNull(12345, TestMessage.instanceCreator); + expect(message.getInt32WithDefault(20)).toBe(1); + }); + + it('removes duplicates when read', () => { + const arrayBuffer = createArrayBuffer( + 0x52, // Tag (field:10, length delimited) + 0x14, // Length of 20 bytes + 0x0B, // Start group fieldnumber 1 + 0x10, // Tag (field 2, varint) + 0xB9, // 12345 + 0x60, // 12345 + 0x1A, // Tag (field 3, length delimited) + 0x03, // length 3 + 0xA0, // Tag (fieldnumber 20, varint) + 0x01, // Tag (fieldnumber 20, varint) + 0x1E, // 30 + 0x0C, // Stop Group field number 1 + // second group + 0x0B, // Start group fieldnumber 1 + 0x10, // Tag (field 2, varint) + 0xB9, // 12345 + 0x60, // 12345 + 0x1A, // Tag (field 3, length delimited) + 0x03, // length 3 + 0xA0, // Tag (fieldnumber 20, varint) + 0x01, // Tag (fieldnumber 20, varint) + 0x01, // 1 + 0x0C // Stop Group field number 1 + ); + + + const outerMessage = Kernel.fromArrayBuffer(arrayBuffer); + outerMessage.getMessageAttach(10, MessageSet.fromKernel); + + expect(outerMessage.serialize()) + .toEqual(createArrayBuffer( + 0x52, // Tag (field:10, length delimited) + 0x0A, // Length of 10 bytes + 0x0B, // Start group fieldnumber 1 + 0x10, // Tag (field 2, varint) + 0xB9, // 12345 + 0x60, // 12345 + 0x1A, // Tag (field 3, length delimited) + 0x03, // length 3 + 0xA0, // Tag (fieldnumber 20, varint) + 0x01, // Tag (fieldnumber 20, varint) + 0x01, // 1 + 0x0C // Stop Group field number 1 + )); + }); + + it('allow for large typeIds', () => { + const messageSet = MessageSet.createEmpty(); + const message = TestMessage.createEmpty(); + messageSet.setMessage(0xFFFFFFFE >>> 0, message); + expect(messageSet.hasMessage(0xFFFFFFFE >>> 0)).toBe(true); + }); +}); + +describe('Optional MessageSet does', () => { + // message Bar { + // optional MessageSet mset = 10; + //} + // + // message Foo { + // extend proto2.bridge.MessageSet { + // optional Foo message_set_extension = 12345; + // } + // optional int32 f20 = 20; + //} + + it('encode as a field', () => { + const fooMessage = Kernel.createEmpty(); + fooMessage.setInt32(20, 30); + + const messageSet = MessageSet.createEmpty(); + messageSet.setMessage(12345, TestMessage.instanceCreator(fooMessage)); + + const barMessage = Kernel.createEmpty(); + barMessage.setMessage(10, messageSet); + + expect(barMessage.serialize()) + .toEqual(createArrayBuffer( + 0x52, // Tag (field:10, length delimited) + 0x0A, // Length of 10 bytes + 0x0B, // Start group fieldnumber 1 + 0x10, // Tag (field 2, varint) + 0xB9, // 12345 + 0x60, // 12345 + 0x1A, // Tag (field 3, length delimited) + 0x03, // length 3 + 0xA0, // Tag (fieldnumber 20, varint) + 0x01, // Tag (fieldnumber 20, varint) + 0x1E, // 30 + 0x0C // Stop Group field number 1 + )); + }); + + it('deserializes', () => { + const fooMessage = Kernel.createEmpty(); + fooMessage.setInt32(20, 30); + + const messageSet = MessageSet.createEmpty(); + messageSet.setMessage(12345, TestMessage.instanceCreator(fooMessage)); + + + const barMessage = Kernel.createEmpty(); + barMessage.setMessage(10, messageSet); + + const arrayBuffer = barMessage.serialize(); + + const barMessageParsed = Kernel.fromArrayBuffer(arrayBuffer); + expect(barMessageParsed.hasFieldNumber(10)).toBe(true); + + const messageSetParsed = + barMessageParsed.getMessage(10, MessageSet.fromKernel); + + const fooMessageParsed = + messageSetParsed.getMessageOrNull(12345, TestMessage.instanceCreator) + .internalGetKernel(); + + expect(fooMessageParsed.getInt32WithDefault(20)).toBe(30); + }); +}); diff --git a/js/experimental/runtime/kernel/tag.js b/js/experimental/runtime/kernel/tag.js new file mode 100644 index 0000000000..b288df3b8e --- /dev/null +++ b/js/experimental/runtime/kernel/tag.js @@ -0,0 +1,144 @@ +goog.module('protobuf.binary.tag'); + +const BufferDecoder = goog.require('protobuf.binary.BufferDecoder'); +const WireType = goog.require('protobuf.binary.WireType'); +const {checkCriticalElementIndex, checkCriticalState} = goog.require('protobuf.internal.checks'); + +/** + * Returns wire type stored in a tag. + * Protos store the wire type as the first 3 bit of a tag. + * @param {number} tag + * @return {!WireType} + */ +function tagToWireType(tag) { + return /** @type {!WireType} */ (tag & 0x07); +} + +/** + * Returns the field number stored in a tag. + * Protos store the field number in the upper 29 bits of a 32 bit number. + * @param {number} tag + * @return {number} + */ +function tagToFieldNumber(tag) { + return tag >>> 3; +} + +/** + * Combines wireType and fieldNumber into a tag. + * @param {!WireType} wireType + * @param {number} fieldNumber + * @return {number} + */ +function createTag(wireType, fieldNumber) { + return (fieldNumber << 3 | wireType) >>> 0; +} + +/** + * Returns the length, in bytes, of the field in the tag stream, less the tag + * itself. + * Note: This moves the cursor in the bufferDecoder. + * @param {!BufferDecoder} bufferDecoder + * @param {number} start + * @param {!WireType} wireType + * @param {number} fieldNumber + * @return {number} + * @private + */ +function getTagLength(bufferDecoder, start, wireType, fieldNumber) { + bufferDecoder.setCursor(start); + skipField(bufferDecoder, wireType, fieldNumber); + return bufferDecoder.cursor() - start; +} + +/** + * @param {number} value + * @return {number} + */ +function get32BitVarintLength(value) { + if (value < 0) { + return 5; + } + let size = 1; + while (value >= 128) { + size++; + value >>>= 7; + } + return size; +} + +/** + * Skips over a field. + * Note: If the field is a start group the entire group will be skipped, placing + * the cursor onto the next field. + * @param {!BufferDecoder} bufferDecoder + * @param {!WireType} wireType + * @param {number} fieldNumber + */ +function skipField(bufferDecoder, wireType, fieldNumber) { + switch (wireType) { + case WireType.VARINT: + checkCriticalElementIndex( + bufferDecoder.cursor(), bufferDecoder.endIndex()); + bufferDecoder.skipVarint(); + return; + case WireType.FIXED64: + bufferDecoder.skip(8); + return; + case WireType.DELIMITED: + checkCriticalElementIndex( + bufferDecoder.cursor(), bufferDecoder.endIndex()); + const length = bufferDecoder.getUnsignedVarint32(); + bufferDecoder.skip(length); + return; + case WireType.START_GROUP: + const foundGroup = skipGroup_(bufferDecoder, fieldNumber); + checkCriticalState(foundGroup, 'No end group found.'); + return; + case WireType.FIXED32: + bufferDecoder.skip(4); + return; + default: + throw new Error(`Unexpected wire type: ${wireType}`); + } +} + +/** + * Skips over fields until it finds the end of a given group consuming the stop + * group tag. + * @param {!BufferDecoder} bufferDecoder + * @param {number} groupFieldNumber + * @return {boolean} Whether the end group tag was found. + * @private + */ +function skipGroup_(bufferDecoder, groupFieldNumber) { + // On a start group we need to keep skipping fields until we find a + // corresponding stop group + // Note: Since we are calling skipField from here nested groups will be + // handled by recursion of this method and thus we will not see a nested + // STOP GROUP here unless there is something wrong with the input data. + while (bufferDecoder.hasNext()) { + const tag = bufferDecoder.getUnsignedVarint32(); + const wireType = tagToWireType(tag); + const fieldNumber = tagToFieldNumber(tag); + + if (wireType === WireType.END_GROUP) { + checkCriticalState( + groupFieldNumber === fieldNumber, + `Expected stop group for fieldnumber ${groupFieldNumber} not found.`); + return true; + } else { + skipField(bufferDecoder, wireType, fieldNumber); + } + } + return false; +} + +exports = { + createTag, + get32BitVarintLength, + getTagLength, + skipField, + tagToWireType, + tagToFieldNumber, +}; diff --git a/js/experimental/runtime/kernel/tag_test.js b/js/experimental/runtime/kernel/tag_test.js new file mode 100644 index 0000000000..04a6cb6668 --- /dev/null +++ b/js/experimental/runtime/kernel/tag_test.js @@ -0,0 +1,221 @@ +/** + * @fileoverview Tests for tag.js. + */ +goog.module('protobuf.binary.TagTests'); + +const BufferDecoder = goog.require('protobuf.binary.BufferDecoder'); +const WireType = goog.require('protobuf.binary.WireType'); +const {CHECK_CRITICAL_STATE} = goog.require('protobuf.internal.checks'); +const {createTag, get32BitVarintLength, skipField, tagToFieldNumber, tagToWireType} = goog.require('protobuf.binary.tag'); + + +goog.setTestOnly(); + +/** + * @param {...number} bytes + * @return {!ArrayBuffer} + */ +function createArrayBuffer(...bytes) { + return new Uint8Array(bytes).buffer; +} + +describe('skipField', () => { + it('skips varints', () => { + const bufferDecoder = + BufferDecoder.fromArrayBuffer(createArrayBuffer(0x80, 0x00)); + skipField(bufferDecoder, WireType.VARINT, 1); + expect(bufferDecoder.cursor()).toBe(2); + }); + + it('throws for out of bounds varints', () => { + const bufferDecoder = + BufferDecoder.fromArrayBuffer(createArrayBuffer(0x80, 0x00)); + bufferDecoder.setCursor(2); + if (CHECK_CRITICAL_STATE) { + expect(() => skipField(bufferDecoder, WireType.VARINT, 1)).toThrowError(); + } + }); + + it('skips fixed64', () => { + const bufferDecoder = BufferDecoder.fromArrayBuffer( + createArrayBuffer(0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)); + skipField(bufferDecoder, WireType.FIXED64, 1); + expect(bufferDecoder.cursor()).toBe(8); + }); + + it('throws for fixed64 if length is too short', () => { + const bufferDecoder = + BufferDecoder.fromArrayBuffer(createArrayBuffer(0x80, 0x00)); + if (CHECK_CRITICAL_STATE) { + expect(() => skipField(bufferDecoder, WireType.FIXED64, 1)) + .toThrowError(); + } + }); + + it('skips fixed32', () => { + const bufferDecoder = BufferDecoder.fromArrayBuffer( + createArrayBuffer(0x80, 0x00, 0x00, 0x00)); + skipField(bufferDecoder, WireType.FIXED32, 1); + expect(bufferDecoder.cursor()).toBe(4); + }); + + it('throws for fixed32 if length is too short', () => { + const bufferDecoder = + BufferDecoder.fromArrayBuffer(createArrayBuffer(0x80, 0x00)); + if (CHECK_CRITICAL_STATE) { + expect(() => skipField(bufferDecoder, WireType.FIXED32, 1)) + .toThrowError(); + } + }); + + + it('skips length delimited', () => { + const bufferDecoder = BufferDecoder.fromArrayBuffer( + createArrayBuffer(0x03, 0x00, 0x00, 0x00)); + skipField(bufferDecoder, WireType.DELIMITED, 1); + expect(bufferDecoder.cursor()).toBe(4); + }); + + it('throws for length delimited if length is too short', () => { + const bufferDecoder = + BufferDecoder.fromArrayBuffer(createArrayBuffer(0x03, 0x00, 0x00)); + if (CHECK_CRITICAL_STATE) { + expect(() => skipField(bufferDecoder, WireType.DELIMITED, 1)) + .toThrowError(); + } + }); + + it('skips groups', () => { + const bufferDecoder = BufferDecoder.fromArrayBuffer( + createArrayBuffer(0x0B, 0x08, 0x01, 0x0C)); + bufferDecoder.setCursor(1); + skipField(bufferDecoder, WireType.START_GROUP, 1); + expect(bufferDecoder.cursor()).toBe(4); + }); + + it('skips group in group', () => { + const buffer = createArrayBuffer( + 0x0B, // start outter + 0x10, 0x01, // field: 2, value: 1 + 0x0B, // start inner group + 0x10, 0x01, // payload inner group + 0x0C, // stop inner group + 0x0C // end outter + ); + const bufferDecoder = BufferDecoder.fromArrayBuffer(buffer); + bufferDecoder.setCursor(1); + skipField(bufferDecoder, WireType.START_GROUP, 1); + expect(bufferDecoder.cursor()).toBe(8); + }); + + it('throws for group if length is too short', () => { + // no closing group + const bufferDecoder = + BufferDecoder.fromArrayBuffer(createArrayBuffer(0x0B, 0x00, 0x00)); + if (CHECK_CRITICAL_STATE) { + expect(() => skipField(bufferDecoder, WireType.START_GROUP, 1)) + .toThrowError(); + } + }); +}); + + +describe('tagToWireType', () => { + it('decodes numbers ', () => { + // simple numbers + expect(tagToWireType(0x00)).toBe(WireType.VARINT); + expect(tagToWireType(0x01)).toBe(WireType.FIXED64); + expect(tagToWireType(0x02)).toBe(WireType.DELIMITED); + expect(tagToWireType(0x03)).toBe(WireType.START_GROUP); + expect(tagToWireType(0x04)).toBe(WireType.END_GROUP); + expect(tagToWireType(0x05)).toBe(WireType.FIXED32); + + // upper bits should not matter + expect(tagToWireType(0x08)).toBe(WireType.VARINT); + expect(tagToWireType(0x09)).toBe(WireType.FIXED64); + expect(tagToWireType(0x0A)).toBe(WireType.DELIMITED); + expect(tagToWireType(0x0B)).toBe(WireType.START_GROUP); + expect(tagToWireType(0x0C)).toBe(WireType.END_GROUP); + expect(tagToWireType(0x0D)).toBe(WireType.FIXED32); + + // upper bits should not matter + expect(tagToWireType(0xF8)).toBe(WireType.VARINT); + expect(tagToWireType(0xF9)).toBe(WireType.FIXED64); + expect(tagToWireType(0xFA)).toBe(WireType.DELIMITED); + expect(tagToWireType(0xFB)).toBe(WireType.START_GROUP); + expect(tagToWireType(0xFC)).toBe(WireType.END_GROUP); + expect(tagToWireType(0xFD)).toBe(WireType.FIXED32); + + // negative numbers work + expect(tagToWireType(-8)).toBe(WireType.VARINT); + expect(tagToWireType(-7)).toBe(WireType.FIXED64); + expect(tagToWireType(-6)).toBe(WireType.DELIMITED); + expect(tagToWireType(-5)).toBe(WireType.START_GROUP); + expect(tagToWireType(-4)).toBe(WireType.END_GROUP); + expect(tagToWireType(-3)).toBe(WireType.FIXED32); + }); +}); + +describe('tagToFieldNumber', () => { + it('returns fieldNumber', () => { + expect(tagToFieldNumber(0x08)).toBe(1); + expect(tagToFieldNumber(0x09)).toBe(1); + expect(tagToFieldNumber(0x10)).toBe(2); + expect(tagToFieldNumber(0x12)).toBe(2); + }); +}); + +describe('createTag', () => { + it('combines fieldNumber and wireType', () => { + expect(createTag(WireType.VARINT, 1)).toBe(0x08); + expect(createTag(WireType.FIXED64, 1)).toBe(0x09); + expect(createTag(WireType.DELIMITED, 1)).toBe(0x0A); + expect(createTag(WireType.START_GROUP, 1)).toBe(0x0B); + expect(createTag(WireType.END_GROUP, 1)).toBe(0x0C); + expect(createTag(WireType.FIXED32, 1)).toBe(0x0D); + + expect(createTag(WireType.VARINT, 2)).toBe(0x10); + expect(createTag(WireType.FIXED64, 2)).toBe(0x11); + expect(createTag(WireType.DELIMITED, 2)).toBe(0x12); + expect(createTag(WireType.START_GROUP, 2)).toBe(0x13); + expect(createTag(WireType.END_GROUP, 2)).toBe(0x14); + expect(createTag(WireType.FIXED32, 2)).toBe(0x15); + + expect(createTag(WireType.VARINT, 0x1FFFFFFF)).toBe(0xFFFFFFF8 >>> 0); + expect(createTag(WireType.FIXED64, 0x1FFFFFFF)).toBe(0xFFFFFFF9 >>> 0); + expect(createTag(WireType.DELIMITED, 0x1FFFFFFF)).toBe(0xFFFFFFFA >>> 0); + expect(createTag(WireType.START_GROUP, 0x1FFFFFFF)).toBe(0xFFFFFFFB >>> 0); + expect(createTag(WireType.END_GROUP, 0x1FFFFFFF)).toBe(0xFFFFFFFC >>> 0); + expect(createTag(WireType.FIXED32, 0x1FFFFFFF)).toBe(0xFFFFFFFD >>> 0); + }); +}); + +describe('get32BitVarintLength', () => { + it('length of tag', () => { + expect(get32BitVarintLength(0)).toBe(1); + expect(get32BitVarintLength(1)).toBe(1); + expect(get32BitVarintLength(1)).toBe(1); + + expect(get32BitVarintLength(Math.pow(2, 7) - 1)).toBe(1); + expect(get32BitVarintLength(Math.pow(2, 7))).toBe(2); + + expect(get32BitVarintLength(Math.pow(2, 14) - 1)).toBe(2); + expect(get32BitVarintLength(Math.pow(2, 14))).toBe(3); + + expect(get32BitVarintLength(Math.pow(2, 21) - 1)).toBe(3); + expect(get32BitVarintLength(Math.pow(2, 21))).toBe(4); + + expect(get32BitVarintLength(Math.pow(2, 28) - 1)).toBe(4); + expect(get32BitVarintLength(Math.pow(2, 28))).toBe(5); + + expect(get32BitVarintLength(Math.pow(2, 31) - 1)).toBe(5); + + expect(get32BitVarintLength(-1)).toBe(5); + expect(get32BitVarintLength(-Math.pow(2, 31))).toBe(5); + + expect(get32BitVarintLength(createTag(WireType.VARINT, 0x1fffffff))) + .toBe(5); + expect(get32BitVarintLength(createTag(WireType.FIXED32, 0x1fffffff))) + .toBe(5); + }); +}); diff --git a/js/experimental/runtime/kernel/writer.js b/js/experimental/runtime/kernel/writer.js index 8af7a06d0d..5b8b79b6c3 100644 --- a/js/experimental/runtime/kernel/writer.js +++ b/js/experimental/runtime/kernel/writer.js @@ -10,6 +10,7 @@ const Int64 = goog.require('protobuf.Int64'); const WireType = goog.require('protobuf.binary.WireType'); const {POLYFILL_TEXT_ENCODING, checkFieldNumber, checkTypeUnsignedInt32, checkWireType} = goog.require('protobuf.internal.checks'); const {concatenateByteArrays} = goog.require('protobuf.binary.uint8arrays'); +const {createTag, getTagLength} = goog.require('protobuf.binary.tag'); const {encode} = goog.require('protobuf.binary.textencoding'); /** @@ -92,8 +93,8 @@ class Writer { writeTag(fieldNumber, wireType) { checkFieldNumber(fieldNumber); checkWireType(wireType); - const tag = fieldNumber << 3 | wireType; - this.writeUnsignedVarint32_(tag >>> 0); + const tag = createTag(wireType, fieldNumber); + this.writeUnsignedVarint32_(tag); } /** @@ -299,6 +300,22 @@ class Writer { this.writeSfixed64Value_(value); } + /** + * Writes a sfixed64 value field to the buffer. + * @param {number} fieldNumber + */ + writeStartGroup(fieldNumber) { + this.writeTag(fieldNumber, WireType.START_GROUP); + } + + /** + * Writes a sfixed64 value field to the buffer. + * @param {number} fieldNumber + */ + writeEndGroup(fieldNumber) { + this.writeTag(fieldNumber, WireType.END_GROUP); + } + /** * Writes a uint32 value field to the buffer as a varint without tag. * @param {number} value @@ -430,67 +447,17 @@ class Writer { * @param {!BufferDecoder} bufferDecoder * @param {number} start * @param {!WireType} wireType + * @param {number} fieldNumber * @package */ - writeBufferDecoder(bufferDecoder, start, wireType) { + writeBufferDecoder(bufferDecoder, start, wireType, fieldNumber) { this.closeAndStartNewBuffer_(); - const dataLength = this.getLength_(bufferDecoder, start, wireType); + const dataLength = + getTagLength(bufferDecoder, start, wireType, fieldNumber); this.blocks_.push( bufferDecoder.subBufferDecoder(start, dataLength).asUint8Array()); } - /** - * Returns the length of the data to serialize. Returns -1 when a STOP GROUP - * is found. - * @param {!BufferDecoder} bufferDecoder - * @param {number} start - * @param {!WireType} wireType - * @return {number} - * @private - */ - getLength_(bufferDecoder, start, wireType) { - switch (wireType) { - case WireType.VARINT: - bufferDecoder.setCursor(start); - bufferDecoder.skipVarint(); - return bufferDecoder.cursor() - start; - case WireType.FIXED64: - return 8; - case WireType.DELIMITED: - const dataLength = bufferDecoder.getUnsignedVarint32At(start); - return dataLength + bufferDecoder.cursor() - start; - case WireType.START_GROUP: - return this.getGroupLength_(bufferDecoder, start); - case WireType.FIXED32: - return 4; - default: - throw new Error(`Invalid wire type: ${wireType}`); - } - } - - /** - * Skips over fields until it finds the end of a given group. - * @param {!BufferDecoder} bufferDecoder - * @param {number} start - * @return {number} - * @private - */ - getGroupLength_(bufferDecoder, start) { - // On a start group we need to keep skipping fields until we find a - // corresponding stop group - let cursor = start; - while (cursor < bufferDecoder.endIndex()) { - const tag = bufferDecoder.getUnsignedVarint32At(cursor); - const wireType = /** @type {!WireType} */ (tag & 0x07); - if (wireType === WireType.END_GROUP) { - return bufferDecoder.cursor() - start; - } - cursor = bufferDecoder.cursor() + - this.getLength_(bufferDecoder, bufferDecoder.cursor(), wireType); - } - throw new Error('No end group found'); - } - /** * Write the whole bytes as a length delimited field. * @param {number} fieldNumber diff --git a/js/experimental/runtime/kernel/writer_test.js b/js/experimental/runtime/kernel/writer_test.js index 331ab2b143..019ae1e18a 100644 --- a/js/experimental/runtime/kernel/writer_test.js +++ b/js/experimental/runtime/kernel/writer_test.js @@ -148,7 +148,7 @@ describe('Writer.writeBufferDecoder does', () => { const expected = createArrayBuffer( 0x08, /* varint start= */ 0xFF, /* varint end= */ 0x01, 0x08, 0x01); writer.writeBufferDecoder( - BufferDecoder.fromArrayBuffer(expected), 1, WireType.VARINT); + BufferDecoder.fromArrayBuffer(expected), 1, WireType.VARINT, 1); const result = writer.getAndResetResultBuffer(); expect(result).toEqual(arrayBufferSlice(expected, 1, 3)); }); @@ -159,7 +159,7 @@ describe('Writer.writeBufferDecoder does', () => { 0x09, /* fixed64 start= */ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, /* fixed64 end= */ 0x08, 0x08, 0x01); writer.writeBufferDecoder( - BufferDecoder.fromArrayBuffer(expected), 1, WireType.FIXED64); + BufferDecoder.fromArrayBuffer(expected), 1, WireType.FIXED64, 1); const result = writer.getAndResetResultBuffer(); expect(result).toEqual(arrayBufferSlice(expected, 1, 9)); }); @@ -170,7 +170,7 @@ describe('Writer.writeBufferDecoder does', () => { 0xA, /* length= */ 0x03, /* data start= */ 0x01, 0x02, /* data end= */ 0x03, 0x08, 0x01); writer.writeBufferDecoder( - BufferDecoder.fromArrayBuffer(expected), 1, WireType.DELIMITED); + BufferDecoder.fromArrayBuffer(expected), 1, WireType.DELIMITED, 1); const result = writer.getAndResetResultBuffer(); expect(result).toEqual(arrayBufferSlice(expected, 1, 5)); }); @@ -181,7 +181,7 @@ describe('Writer.writeBufferDecoder does', () => { 0xB, /* group start= */ 0x08, 0x01, /* nested group start= */ 0x0B, /* nested group end= */ 0x0C, /* group end= */ 0x0C, 0x08, 0x01); writer.writeBufferDecoder( - BufferDecoder.fromArrayBuffer(expected), 1, WireType.START_GROUP); + BufferDecoder.fromArrayBuffer(expected), 1, WireType.START_GROUP, 1); const result = writer.getAndResetResultBuffer(); expect(result).toEqual(arrayBufferSlice(expected, 1, 6)); }); @@ -192,7 +192,7 @@ describe('Writer.writeBufferDecoder does', () => { 0x09, /* fixed64 start= */ 0x01, 0x02, 0x03, /* fixed64 end= */ 0x04, 0x08, 0x01); writer.writeBufferDecoder( - BufferDecoder.fromArrayBuffer(expected), 1, WireType.FIXED32); + BufferDecoder.fromArrayBuffer(expected), 1, WireType.FIXED32, 1); const result = writer.getAndResetResultBuffer(); expect(result).toEqual(arrayBufferSlice(expected, 1, 5)); }); @@ -203,7 +203,7 @@ describe('Writer.writeBufferDecoder does', () => { const subBuffer = arrayBufferSlice(buffer, 0, 2); expect( () => writer.writeBufferDecoder( - BufferDecoder.fromArrayBuffer(subBuffer), 0, WireType.DELIMITED)) + BufferDecoder.fromArrayBuffer(subBuffer), 0, WireType.DELIMITED, 1)) .toThrow(); }); }); diff --git a/js/experimental/runtime/testing/binary/test_message.js b/js/experimental/runtime/testing/binary/test_message.js index a7aa8a1522..cfd264b324 100644 --- a/js/experimental/runtime/testing/binary/test_message.js +++ b/js/experimental/runtime/testing/binary/test_message.js @@ -13,6 +13,13 @@ const Kernel = goog.require('protobuf.runtime.Kernel'); * @implements {InternalMessage} */ class TestMessage { + /** + * @return {!TestMessage} + */ + static createEmpty() { + return TestMessage.instanceCreator(Kernel.createEmpty()); + } + /** * @param {!Kernel} kernel * @return {!TestMessage} @@ -31,7 +38,6 @@ class TestMessage { /** * @override - * @package * @return {!Kernel} */ internalGetKernel() { @@ -1760,4 +1766,4 @@ class TestMessage { } } -exports = TestMessage; \ No newline at end of file +exports = TestMessage; diff --git a/protobuf.bzl b/protobuf.bzl index e2821f5b5e..8d67620f83 100644 --- a/protobuf.bzl +++ b/protobuf.bzl @@ -432,7 +432,7 @@ def py_proto_library( protoc: the label of the protocol compiler to generate the sources. use_grpc_plugin: a flag to indicate whether to call the Python C++ plugin when processing the proto files. - **kargs: other keyword arguments that are passed to cc_library. + **kargs: other keyword arguments that are passed to py_library. """ outs = _PyOuts(srcs, use_grpc_plugin) diff --git a/python/google/protobuf/internal/test_proto3_optional.proto b/python/google/protobuf/internal/test_proto3_optional.proto index a5abd19d13..f3e0a2e761 100644 --- a/python/google/protobuf/internal/test_proto3_optional.proto +++ b/python/google/protobuf/internal/test_proto3_optional.proto @@ -30,10 +30,7 @@ syntax = "proto3"; -package protobuf_unittest; - -option java_multiple_files = true; -option java_package = "com.google.protobuf.testing.proto"; +package google.protobuf.python.internal; message TestProto3Optional { message NestedMessage { diff --git a/src/google/protobuf/api.pb.h b/src/google/protobuf/api.pb.h index 66ac61af89..21d57a7f4e 100644 --- a/src/google/protobuf/api.pb.h +++ b/src/google/protobuf/api.pb.h @@ -1063,7 +1063,9 @@ inline void Api::unsafe_arena_set_allocated_source_context( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.Api.source_context) } inline PROTOBUF_NAMESPACE_ID::SourceContext* Api::release_source_context() { - auto temp = unsafe_arena_release_source_context(); + + PROTOBUF_NAMESPACE_ID::SourceContext* temp = source_context_; + source_context_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } diff --git a/src/google/protobuf/compiler/cpp/cpp_message.cc b/src/google/protobuf/compiler/cpp/cpp_message.cc index 1cc089196e..b9ba22fda7 100644 --- a/src/google/protobuf/compiler/cpp/cpp_message.cc +++ b/src/google/protobuf/compiler/cpp/cpp_message.cc @@ -614,7 +614,6 @@ MessageGenerator::MessageGenerator( } } - if (!has_bit_indices_.empty()) { field_generators_.SetHasBitIndices(has_bit_indices_); } diff --git a/src/google/protobuf/compiler/cpp/cpp_message_field.cc b/src/google/protobuf/compiler/cpp/cpp_message_field.cc index b6b8f24b6a..38fcb52e6c 100644 --- a/src/google/protobuf/compiler/cpp/cpp_message_field.cc +++ b/src/google/protobuf/compiler/cpp/cpp_message_field.cc @@ -199,7 +199,10 @@ void MessageFieldGenerator::GenerateInlineAccessorDefinitions( "}\n"); format( "inline $type$* $classname$::$release_name$() {\n" - " auto temp = unsafe_arena_release_$name$();\n" + "$type_reference_function$" + " $clear_hasbit$\n" + " $type$* temp = $casted_member$;\n" + " $name$_ = nullptr;\n" " if (GetArena() != nullptr) {\n" " temp = ::$proto_ns$::internal::DuplicateIfNonNull(temp);\n" " }\n" diff --git a/src/google/protobuf/compiler/java/java_enum_field.cc b/src/google/protobuf/compiler/java/java_enum_field.cc index 6322ee566e..32cff15fec 100644 --- a/src/google/protobuf/compiler/java/java_enum_field.cc +++ b/src/google/protobuf/compiler/java/java_enum_field.cc @@ -225,6 +225,7 @@ void ImmutableEnumFieldGenerator::GenerateBuilderMembers( printer->Print(variables_, "$deprecation$public Builder " "${$set$capitalized_name$Value$}$(int value) {\n" + " $set_has_field_bit_builder$\n" " $name$_ = value;\n" " $on_changed$\n" " return this;\n" diff --git a/src/google/protobuf/compiler/plugin.pb.h b/src/google/protobuf/compiler/plugin.pb.h index 80ebc07a7d..2ae03620c5 100644 --- a/src/google/protobuf/compiler/plugin.pb.h +++ b/src/google/protobuf/compiler/plugin.pb.h @@ -1441,7 +1441,9 @@ inline void CodeGeneratorRequest::unsafe_arena_set_allocated_compiler_version( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) } inline PROTOBUF_NAMESPACE_ID::compiler::Version* CodeGeneratorRequest::release_compiler_version() { - auto temp = unsafe_arena_release_compiler_version(); + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::compiler::Version* temp = compiler_version_; + compiler_version_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } diff --git a/src/google/protobuf/descriptor.pb.h b/src/google/protobuf/descriptor.pb.h index 4a6dbf489b..f8fc5544a4 100644 --- a/src/google/protobuf/descriptor.pb.h +++ b/src/google/protobuf/descriptor.pb.h @@ -7484,7 +7484,9 @@ inline void FileDescriptorProto::unsafe_arena_set_allocated_options( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.FileDescriptorProto.options) } inline PROTOBUF_NAMESPACE_ID::FileOptions* FileDescriptorProto::release_options() { - auto temp = unsafe_arena_release_options(); + _has_bits_[0] &= ~0x00000008u; + PROTOBUF_NAMESPACE_ID::FileOptions* temp = options_; + options_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -7565,7 +7567,9 @@ inline void FileDescriptorProto::unsafe_arena_set_allocated_source_code_info( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.FileDescriptorProto.source_code_info) } inline PROTOBUF_NAMESPACE_ID::SourceCodeInfo* FileDescriptorProto::release_source_code_info() { - auto temp = unsafe_arena_release_source_code_info(); + _has_bits_[0] &= ~0x00000010u; + PROTOBUF_NAMESPACE_ID::SourceCodeInfo* temp = source_code_info_; + source_code_info_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -7799,7 +7803,9 @@ inline void DescriptorProto_ExtensionRange::unsafe_arena_set_allocated_options( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.DescriptorProto.ExtensionRange.options) } inline PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* DescriptorProto_ExtensionRange::release_options() { - auto temp = unsafe_arena_release_options(); + _has_bits_[0] &= ~0x00000001u; + PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* temp = options_; + options_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -8271,7 +8277,9 @@ inline void DescriptorProto::unsafe_arena_set_allocated_options( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.DescriptorProto.options) } inline PROTOBUF_NAMESPACE_ID::MessageOptions* DescriptorProto::release_options() { - auto temp = unsafe_arena_release_options(); + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::MessageOptions* temp = options_; + options_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -9091,7 +9099,9 @@ inline void FieldDescriptorProto::unsafe_arena_set_allocated_options( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.FieldDescriptorProto.options) } inline PROTOBUF_NAMESPACE_ID::FieldOptions* FieldDescriptorProto::release_options() { - auto temp = unsafe_arena_release_options(); + _has_bits_[0] &= ~0x00000020u; + PROTOBUF_NAMESPACE_ID::FieldOptions* temp = options_; + options_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -9297,7 +9307,9 @@ inline void OneofDescriptorProto::unsafe_arena_set_allocated_options( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.OneofDescriptorProto.options) } inline PROTOBUF_NAMESPACE_ID::OneofOptions* OneofDescriptorProto::release_options() { - auto temp = unsafe_arena_release_options(); + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::OneofOptions* temp = options_; + options_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -9574,7 +9586,9 @@ inline void EnumDescriptorProto::unsafe_arena_set_allocated_options( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.EnumDescriptorProto.options) } inline PROTOBUF_NAMESPACE_ID::EnumOptions* EnumDescriptorProto::release_options() { - auto temp = unsafe_arena_release_options(); + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::EnumOptions* temp = options_; + options_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -9893,7 +9907,9 @@ inline void EnumValueDescriptorProto::unsafe_arena_set_allocated_options( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.EnumValueDescriptorProto.options) } inline PROTOBUF_NAMESPACE_ID::EnumValueOptions* EnumValueDescriptorProto::release_options() { - auto temp = unsafe_arena_release_options(); + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::EnumValueOptions* temp = options_; + options_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -10110,7 +10126,9 @@ inline void ServiceDescriptorProto::unsafe_arena_set_allocated_options( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.ServiceDescriptorProto.options) } inline PROTOBUF_NAMESPACE_ID::ServiceOptions* ServiceDescriptorProto::release_options() { - auto temp = unsafe_arena_release_options(); + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::ServiceOptions* temp = options_; + options_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -10474,7 +10492,9 @@ inline void MethodDescriptorProto::unsafe_arena_set_allocated_options( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.MethodDescriptorProto.options) } inline PROTOBUF_NAMESPACE_ID::MethodOptions* MethodDescriptorProto::release_options() { - auto temp = unsafe_arena_release_options(); + _has_bits_[0] &= ~0x00000008u; + PROTOBUF_NAMESPACE_ID::MethodOptions* temp = options_; + options_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } diff --git a/src/google/protobuf/descriptor_database.cc b/src/google/protobuf/descriptor_database.cc index 30f8bdec58..256ae48b30 100644 --- a/src/google/protobuf/descriptor_database.cc +++ b/src/google/protobuf/descriptor_database.cc @@ -146,6 +146,53 @@ bool SimpleDescriptorDatabase::DescriptorIndex::AddFile( return true; } +namespace { + +// Returns true if and only if all characters in the name are alphanumerics, +// underscores, or periods. +bool ValidateSymbolName(StringPiece name) { + for (char c : name) { + // I don't trust ctype.h due to locales. :( + if (c != '.' && c != '_' && (c < '0' || c > '9') && (c < 'A' || c > 'Z') && + (c < 'a' || c > 'z')) { + return false; + } + } + return true; +} + +// Find the last key in the container which sorts less than or equal to the +// symbol name. Since upper_bound() returns the *first* key that sorts +// *greater* than the input, we want the element immediately before that. +template +typename Container::const_iterator FindLastLessOrEqual(Container* container, + const Key& key) { + auto iter = container->upper_bound(key); + if (iter != container->begin()) --iter; + return iter; +} + +// As above, but using std::upper_bound instead. +template +typename Container::const_iterator FindLastLessOrEqual(Container* container, + const Key& key, + const Cmp& cmp) { + auto iter = std::upper_bound(container->begin(), container->end(), key, cmp); + if (iter != container->begin()) --iter; + return iter; +} + +// True if either the arguments are equal or super_symbol identifies a +// parent symbol of sub_symbol (e.g. "foo.bar" is a parent of +// "foo.bar.baz", but not a parent of "foo.barbaz"). +bool IsSubSymbol(StringPiece sub_symbol, StringPiece super_symbol) { + return sub_symbol == super_symbol || + (HasPrefixString(super_symbol, sub_symbol) && + super_symbol[sub_symbol.size()] == '.'); +} + +} // namespace + template bool SimpleDescriptorDatabase::DescriptorIndex::AddSymbol( const std::string& name, Value value) { @@ -161,8 +208,7 @@ bool SimpleDescriptorDatabase::DescriptorIndex::AddSymbol( // Try to look up the symbol to make sure a super-symbol doesn't already // exist. - typename std::map::iterator iter = - FindLastLessOrEqual(name); + auto iter = FindLastLessOrEqual(&by_symbol_, name); if (iter == by_symbol_.end()) { // Apparently the map is currently empty. Just insert and be done with it. @@ -252,8 +298,7 @@ Value SimpleDescriptorDatabase::DescriptorIndex::FindFile( template Value SimpleDescriptorDatabase::DescriptorIndex::FindSymbol( const std::string& name) { - typename std::map::iterator iter = - FindLastLessOrEqual(name); + auto iter = FindLastLessOrEqual(&by_symbol_, name); return (iter != by_symbol_.end() && IsSubSymbol(iter->first, name)) ? iter->second @@ -294,40 +339,6 @@ void SimpleDescriptorDatabase::DescriptorIndex::FindAllFileNames( } } -template -typename std::map::iterator -SimpleDescriptorDatabase::DescriptorIndex::FindLastLessOrEqual( - const std::string& name) { - // Find the last key in the map which sorts less than or equal to the - // symbol name. Since upper_bound() returns the *first* key that sorts - // *greater* than the input, we want the element immediately before that. - typename std::map::iterator iter = - by_symbol_.upper_bound(name); - if (iter != by_symbol_.begin()) --iter; - return iter; -} - -template -bool SimpleDescriptorDatabase::DescriptorIndex::IsSubSymbol( - const std::string& sub_symbol, const std::string& super_symbol) { - return sub_symbol == super_symbol || - (HasPrefixString(super_symbol, sub_symbol) && - super_symbol[sub_symbol.size()] == '.'); -} - -template -bool SimpleDescriptorDatabase::DescriptorIndex::ValidateSymbolName( - const std::string& name) { - for (int i = 0; i < name.size(); i++) { - // I don't trust ctype.h due to locales. :( - if (name[i] != '.' && name[i] != '_' && (name[i] < '0' || name[i] > '9') && - (name[i] < 'A' || name[i] > 'Z') && (name[i] < 'a' || name[i] > 'z')) { - return false; - } - } - return true; -} - // ------------------------------------------------------------------- bool SimpleDescriptorDatabase::Add(const FileDescriptorProto& file) { @@ -378,19 +389,88 @@ bool SimpleDescriptorDatabase::MaybeCopy(const FileDescriptorProto* file, // ------------------------------------------------------------------- -EncodedDescriptorDatabase::EncodedDescriptorDatabase() {} -EncodedDescriptorDatabase::~EncodedDescriptorDatabase() { - for (int i = 0; i < files_to_delete_.size(); i++) { - operator delete(files_to_delete_[i]); - } -} +class EncodedDescriptorDatabase::DescriptorIndex { + public: + using Value = std::pair; + // Helpers to recursively add particular descriptors and all their contents + // to the index. + bool AddFile(const FileDescriptorProto& file, Value value); + + Value FindFile(StringPiece filename); + Value FindSymbol(StringPiece name); + Value FindSymbolOnlyFlat(StringPiece name) const; + Value FindExtension(StringPiece containing_type, int field_number); + bool FindAllExtensionNumbers(StringPiece containing_type, + std::vector* output); + void FindAllFileNames(std::vector* output) const; + + private: + friend class EncodedDescriptorDatabase; + + bool AddSymbol(StringPiece name, Value value); + bool AddNestedExtensions(StringPiece filename, + const DescriptorProto& message_type, Value value); + bool AddExtension(StringPiece filename, + const FieldDescriptorProto& field, Value value); + + // All the maps below have two representations: + // - a std::set<> where we insert initially. + // - a std::vector<> where we flatten the structure on demand. + // The initial tree helps avoid O(N) behavior of inserting into a sorted + // vector, while the vector reduces the heap requirements of the data + // structure. + + void EnsureFlat(); + + struct Entry { + std::string name; + Value data; + }; + struct Compare { + bool operator()(const Entry& a, const Entry& b) const { + return a.name < b.name; + } + bool operator()(const Entry& a, StringPiece b) const { + return a.name < b; + } + bool operator()(StringPiece a, const Entry& b) const { + return a < b.name; + } + }; + std::set by_name_; + std::vector by_name_flat_; + std::set by_symbol_; + std::vector by_symbol_flat_; + struct ExtensionEntry { + std::string extendee; + int extension_number; + Value data; + }; + struct ExtensionCompare { + bool operator()(const ExtensionEntry& a, const ExtensionEntry& b) const { + return std::tie(a.extendee, a.extension_number) < + std::tie(b.extendee, b.extension_number); + } + bool operator()(const ExtensionEntry& a, + std::tuple b) const { + return std::tie(a.extendee, a.extension_number) < b; + } + bool operator()(std::tuple a, + const ExtensionEntry& b) const { + return a < std::tie(b.extendee, b.extension_number); + } + }; + std::set by_extension_; + std::vector by_extension_flat_; +}; bool EncodedDescriptorDatabase::Add(const void* encoded_file_descriptor, int size) { google::protobuf::Arena arena; auto* file = google::protobuf::Arena::CreateMessage(&arena); if (file->ParseFromArray(encoded_file_descriptor, size)) { - return index_.AddFile(*file, std::make_pair(encoded_file_descriptor, size)); + return index_->AddFile(*file, + std::make_pair(encoded_file_descriptor, size)); } else { GOOGLE_LOG(ERROR) << "Invalid file descriptor data passed to " "EncodedDescriptorDatabase::Add()."; @@ -408,22 +488,22 @@ bool EncodedDescriptorDatabase::AddCopy(const void* encoded_file_descriptor, bool EncodedDescriptorDatabase::FindFileByName(const std::string& filename, FileDescriptorProto* output) { - return MaybeParse(index_.FindFile(filename), output); + return MaybeParse(index_->FindFile(filename), output); } bool EncodedDescriptorDatabase::FindFileContainingSymbol( const std::string& symbol_name, FileDescriptorProto* output) { - return MaybeParse(index_.FindSymbol(symbol_name), output); + return MaybeParse(index_->FindSymbol(symbol_name), output); } bool EncodedDescriptorDatabase::FindNameOfFileContainingSymbol( const std::string& symbol_name, std::string* output) { - std::pair encoded_file = index_.FindSymbol(symbol_name); + auto encoded_file = index_->FindSymbol(symbol_name); if (encoded_file.first == NULL) return false; // Optimization: The name should be the first field in the encoded message. // Try to just read it directly. - io::CodedInputStream input(reinterpret_cast(encoded_file.first), + io::CodedInputStream input(static_cast(encoded_file.first), encoded_file.second); const uint32 kNameTag = internal::WireFormatLite::MakeTag( @@ -447,18 +527,245 @@ bool EncodedDescriptorDatabase::FindNameOfFileContainingSymbol( bool EncodedDescriptorDatabase::FindFileContainingExtension( const std::string& containing_type, int field_number, FileDescriptorProto* output) { - return MaybeParse(index_.FindExtension(containing_type, field_number), + return MaybeParse(index_->FindExtension(containing_type, field_number), output); } bool EncodedDescriptorDatabase::FindAllExtensionNumbers( const std::string& extendee_type, std::vector* output) { - return index_.FindAllExtensionNumbers(extendee_type, output); + return index_->FindAllExtensionNumbers(extendee_type, output); +} + +bool EncodedDescriptorDatabase::DescriptorIndex::AddFile( + const FileDescriptorProto& file, Value value) { + if (!InsertIfNotPresent(&by_name_, Entry{file.name(), value}) || + std::binary_search(by_name_flat_.begin(), by_name_flat_.end(), + file.name(), by_name_.key_comp())) { + GOOGLE_LOG(ERROR) << "File already exists in database: " << file.name(); + return false; + } + + // We must be careful here -- calling file.package() if file.has_package() is + // false could access an uninitialized static-storage variable if we are being + // run at startup time. + std::string path = file.has_package() ? file.package() : std::string(); + if (!path.empty()) path += '.'; + + for (const auto& message_type : file.message_type()) { + if (!AddSymbol(path + message_type.name(), value)) return false; + if (!AddNestedExtensions(file.name(), message_type, value)) return false; + } + for (const auto& enum_type : file.enum_type()) { + if (!AddSymbol(path + enum_type.name(), value)) return false; + } + for (const auto& extension : file.extension()) { + if (!AddSymbol(path + extension.name(), value)) return false; + if (!AddExtension(file.name(), extension, value)) return false; + } + for (const auto& service : file.service()) { + if (!AddSymbol(path + service.name(), value)) return false; + } + + return true; } +template +static bool CheckForMutualSubsymbols(StringPiece symbol_name, Iter* iter, + Iter2 end) { + if (*iter != end) { + if (IsSubSymbol((*iter)->name, symbol_name)) { + GOOGLE_LOG(ERROR) << "Symbol name \"" << symbol_name + << "\" conflicts with the existing symbol \"" << (*iter)->name + << "\"."; + return false; + } + + // OK, that worked. Now we have to make sure that no symbol in the map is + // a sub-symbol of the one we are inserting. The only symbol which could + // be so is the first symbol that is greater than the new symbol. Since + // |iter| points at the last symbol that is less than or equal, we just have + // to increment it. + ++*iter; + + if (*iter != end && IsSubSymbol(symbol_name, (*iter)->name)) { + GOOGLE_LOG(ERROR) << "Symbol name \"" << symbol_name + << "\" conflicts with the existing symbol \"" << (*iter)->name + << "\"."; + return false; + } + } + return true; +} + +bool EncodedDescriptorDatabase::DescriptorIndex::AddSymbol( + StringPiece name, Value value) { + // We need to make sure not to violate our map invariant. + + // If the symbol name is invalid it could break our lookup algorithm (which + // relies on the fact that '.' sorts before all other characters that are + // valid in symbol names). + if (!ValidateSymbolName(name)) { + GOOGLE_LOG(ERROR) << "Invalid symbol name: " << name; + return false; + } + + Entry entry = {std::string(name), value}; + + auto iter = FindLastLessOrEqual(&by_symbol_, entry); + if (!CheckForMutualSubsymbols(name, &iter, by_symbol_.end())) { + return false; + } + + // Same, but on by_symbol_flat_ + auto flat_iter = + FindLastLessOrEqual(&by_symbol_flat_, name, by_symbol_.key_comp()); + if (!CheckForMutualSubsymbols(name, &flat_iter, by_symbol_flat_.end())) { + return false; + } + + // OK, no conflicts. + + // Insert the new symbol using the iterator as a hint, the new entry will + // appear immediately before the one the iterator is pointing at. + by_symbol_.insert(iter, std::move(entry)); + + return true; +} + +bool EncodedDescriptorDatabase::DescriptorIndex::AddNestedExtensions( + StringPiece filename, const DescriptorProto& message_type, + Value value) { + for (const auto& nested_type : message_type.nested_type()) { + if (!AddNestedExtensions(filename, nested_type, value)) return false; + } + for (const auto& extension : message_type.extension()) { + if (!AddExtension(filename, extension, value)) return false; + } + return true; +} + +bool EncodedDescriptorDatabase::DescriptorIndex::AddExtension( + StringPiece filename, const FieldDescriptorProto& field, + Value value) { + if (!field.extendee().empty() && field.extendee()[0] == '.') { + // The extension is fully-qualified. We can use it as a lookup key in + // the by_symbol_ table. + if (!InsertIfNotPresent(&by_extension_, + ExtensionEntry{field.extendee().substr(1), + field.number(), value}) || + std::binary_search( + by_extension_flat_.begin(), by_extension_flat_.end(), + std::make_pair(field.extendee().substr(1), field.number()), + by_extension_.key_comp())) { + GOOGLE_LOG(ERROR) << "Extension conflicts with extension already in database: " + "extend " + << field.extendee() << " { " << field.name() << " = " + << field.number() << " } from:" << filename; + return false; + } + } else { + // Not fully-qualified. We can't really do anything here, unfortunately. + // We don't consider this an error, though, because the descriptor is + // valid. + } + return true; +} + +std::pair +EncodedDescriptorDatabase::DescriptorIndex::FindSymbol(StringPiece name) { + EnsureFlat(); + return FindSymbolOnlyFlat(name); +} + +std::pair +EncodedDescriptorDatabase::DescriptorIndex::FindSymbolOnlyFlat( + StringPiece name) const { + auto iter = + FindLastLessOrEqual(&by_symbol_flat_, name, by_symbol_.key_comp()); + + return iter != by_symbol_flat_.end() && IsSubSymbol(iter->name, name) + ? iter->data + : Value(); +} + +std::pair +EncodedDescriptorDatabase::DescriptorIndex::FindExtension( + StringPiece containing_type, int field_number) { + EnsureFlat(); + + auto it = std::lower_bound( + by_extension_flat_.begin(), by_extension_flat_.end(), + std::make_tuple(containing_type, field_number), by_extension_.key_comp()); + return it == by_extension_flat_.end() || it->extendee != containing_type || + it->extension_number != field_number + ? std::make_pair(nullptr, 0) + : it->data; +} + +template +static void MergeIntoFlat(std::set* s, std::vector* flat) { + if (s->empty()) return; + std::vector new_flat(s->size() + flat->size()); + std::merge(s->begin(), s->end(), flat->begin(), flat->end(), &new_flat[0], + s->key_comp()); + *flat = std::move(new_flat); + s->clear(); +} + +void EncodedDescriptorDatabase::DescriptorIndex::EnsureFlat() { + // Merge each of the sets into their flat counterpart. + MergeIntoFlat(&by_name_, &by_name_flat_); + MergeIntoFlat(&by_symbol_, &by_symbol_flat_); + MergeIntoFlat(&by_extension_, &by_extension_flat_); +} + +bool EncodedDescriptorDatabase::DescriptorIndex::FindAllExtensionNumbers( + StringPiece containing_type, std::vector* output) { + EnsureFlat(); + + bool success = false; + auto it = std::lower_bound( + by_extension_flat_.begin(), by_extension_flat_.end(), + std::make_tuple(containing_type, 0), by_extension_.key_comp()); + for (; it != by_extension_flat_.end() && it->extendee == containing_type; + ++it) { + output->push_back(it->extension_number); + success = true; + } + + return success; +} + +void EncodedDescriptorDatabase::DescriptorIndex::FindAllFileNames( + std::vector* output) const { + output->resize(by_name_.size() + by_name_flat_.size()); + int i = 0; + for (const auto& entry : by_name_) { + (*output)[i] = entry.name; + i++; + } + for (const auto& entry : by_name_flat_) { + (*output)[i] = entry.name; + i++; + } +} + +std::pair +EncodedDescriptorDatabase::DescriptorIndex::FindFile( + StringPiece filename) { + EnsureFlat(); + + auto it = std::lower_bound(by_name_flat_.begin(), by_name_flat_.end(), + filename, by_name_.key_comp()); + return it == by_name_flat_.end() || it->name != filename + ? std::make_pair(nullptr, 0) + : it->data; +} + + bool EncodedDescriptorDatabase::FindAllFileNames( std::vector* output) { - index_.FindAllFileNames(output); + index_->FindAllFileNames(output); return true; } @@ -468,6 +775,15 @@ bool EncodedDescriptorDatabase::MaybeParse( return output->ParseFromArray(encoded_file.first, encoded_file.second); } +EncodedDescriptorDatabase::EncodedDescriptorDatabase() + : index_(new DescriptorIndex()) {} + +EncodedDescriptorDatabase::~EncodedDescriptorDatabase() { + for (void* p : files_to_delete_) { + operator delete(p); + } +} + // =================================================================== DescriptorPoolDatabase::DescriptorPoolDatabase(const DescriptorPool& pool) diff --git a/src/google/protobuf/descriptor_database.h b/src/google/protobuf/descriptor_database.h index 4009f339e2..10e60fc535 100644 --- a/src/google/protobuf/descriptor_database.h +++ b/src/google/protobuf/descriptor_database.h @@ -266,21 +266,6 @@ class PROTOBUF_EXPORT SimpleDescriptorDatabase : public DescriptorDatabase { // That symbol cannot be a super-symbol of the search key since if it were, // then it would be a match, and we're assuming the match key doesn't exist. // Therefore, step 2 will correctly return no match. - - // Find the last entry in the by_symbol_ map whose key is less than or - // equal to the given name. - typename std::map::iterator FindLastLessOrEqual( - const std::string& name); - - // True if either the arguments are equal or super_symbol identifies a - // parent symbol of sub_symbol (e.g. "foo.bar" is a parent of - // "foo.bar.baz", but not a parent of "foo.barbaz"). - bool IsSubSymbol(const std::string& sub_symbol, - const std::string& super_symbol); - - // Returns true if and only if all characters in the name are alphanumerics, - // underscores, or periods. - bool ValidateSymbolName(const std::string& name); }; DescriptorIndex index_; @@ -332,8 +317,10 @@ class PROTOBUF_EXPORT EncodedDescriptorDatabase : public DescriptorDatabase { bool FindAllFileNames(std::vector* output) override; private: - SimpleDescriptorDatabase::DescriptorIndex > - index_; + class DescriptorIndex; + // Keep DescriptorIndex by pointer to hide the implementation to keep a + // cleaner header. + std::unique_ptr index_; std::vector files_to_delete_; // If encoded_file.first is non-NULL, parse the data into *output and return diff --git a/src/google/protobuf/io/printer.cc b/src/google/protobuf/io/printer.cc index f8d4d2b1f9..95b03f474b 100644 --- a/src/google/protobuf/io/printer.cc +++ b/src/google/protobuf/io/printer.cc @@ -305,7 +305,7 @@ void Printer::FormatInternal(const std::vector& args, } const char* Printer::WriteVariable( - const std::vector& args, + const std::vector& args, const std::map& vars, const char* format, int* arg_index, std::vector* annotations) { auto start = format; diff --git a/src/google/protobuf/io/printer_unittest.cc b/src/google/protobuf/io/printer_unittest.cc index ca45d67813..ed54d1dd1d 100644 --- a/src/google/protobuf/io/printer_unittest.cc +++ b/src/google/protobuf/io/printer_unittest.cc @@ -573,7 +573,7 @@ TEST(Printer, WriteFailurePartial) { EXPECT_TRUE(printer.failed()); // Buffer should contain the first 17 bytes written. - EXPECT_EQ("0123456789abcdef<", string(buffer, sizeof(buffer))); + EXPECT_EQ("0123456789abcdef<", std::string(buffer, sizeof(buffer))); } TEST(Printer, WriteFailureExact) { @@ -595,7 +595,7 @@ TEST(Printer, WriteFailureExact) { EXPECT_TRUE(printer.failed()); // Buffer should contain the first 16 bytes written. - EXPECT_EQ("0123456789abcdef", string(buffer, sizeof(buffer))); + EXPECT_EQ("0123456789abcdef", std::string(buffer, sizeof(buffer))); } TEST(Printer, FormatInternal) { diff --git a/src/google/protobuf/type.pb.h b/src/google/protobuf/type.pb.h index 680679f0d9..9b341fc045 100644 --- a/src/google/protobuf/type.pb.h +++ b/src/google/protobuf/type.pb.h @@ -1689,7 +1689,9 @@ inline void Type::unsafe_arena_set_allocated_source_context( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.Type.source_context) } inline PROTOBUF_NAMESPACE_ID::SourceContext* Type::release_source_context() { - auto temp = unsafe_arena_release_source_context(); + + PROTOBUF_NAMESPACE_ID::SourceContext* temp = source_context_; + source_context_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -2414,7 +2416,9 @@ inline void Enum::unsafe_arena_set_allocated_source_context( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.Enum.source_context) } inline PROTOBUF_NAMESPACE_ID::SourceContext* Enum::release_source_context() { - auto temp = unsafe_arena_release_source_context(); + + PROTOBUF_NAMESPACE_ID::SourceContext* temp = source_context_; + source_context_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } @@ -2738,7 +2742,9 @@ inline void Option::unsafe_arena_set_allocated_value( // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.Option.value) } inline PROTOBUF_NAMESPACE_ID::Any* Option::release_value() { - auto temp = unsafe_arena_release_value(); + + PROTOBUF_NAMESPACE_ID::Any* temp = value_; + value_ = nullptr; if (GetArena() != nullptr) { temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); } diff --git a/src/google/protobuf/unittest_lite.proto b/src/google/protobuf/unittest_lite.proto index 4c3d845853..652966bb83 100644 --- a/src/google/protobuf/unittest_lite.proto +++ b/src/google/protobuf/unittest_lite.proto @@ -38,7 +38,7 @@ package protobuf_unittest; import "google/protobuf/unittest_import_lite.proto"; -option cc_enable_arenas = false; +option cc_enable_arenas = true; option optimize_for = LITE_RUNTIME; option java_package = "com.google.protobuf";