diff --git a/java/core/BUILD.bazel b/java/core/BUILD.bazel index b896c426b9..30c0e90881 100644 --- a/java/core/BUILD.bazel +++ b/java/core/BUILD.bazel @@ -2,10 +2,10 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@rules_java//java:defs.bzl", "java_lite_proto_library", "java_proto_library") load("@rules_pkg//:mappings.bzl", "pkg_files", "strip_prefix") load("@rules_proto//proto:defs.bzl", "proto_lang_toolchain", "proto_library") -load("//build_defs:java_opts.bzl", "protobuf_java_export", "protobuf_java_library", "protobuf_versioned_java_library") -load("//conformance:defs.bzl", "conformance_test") load("//:protobuf.bzl", "internal_gen_well_known_protos_java") load("//:protobuf_version.bzl", "PROTOBUF_JAVA_VERSION") +load("//build_defs:java_opts.bzl", "protobuf_java_export", "protobuf_java_library", "protobuf_versioned_java_library") +load("//conformance:defs.bzl", "conformance_test") load("//java/internal:testing.bzl", "junit_tests") LITE_SRCS = [ @@ -472,6 +472,7 @@ LITE_TEST_EXCLUSIONS = [ "src/test/java/com/google/protobuf/FieldPresenceTest.java", "src/test/java/com/google/protobuf/ForceFieldBuildersPreRun.java", "src/test/java/com/google/protobuf/GeneratedMessageTest.java", + "src/test/java/com/google/protobuf/LazilyParsedMessageSetTest.java", "src/test/java/com/google/protobuf/LazyFieldTest.java", "src/test/java/com/google/protobuf/LazyStringEndToEndTest.java", "src/test/java/com/google/protobuf/MapForProto2Test.java", diff --git a/java/core/src/main/java/com/google/protobuf/FieldSet.java b/java/core/src/main/java/com/google/protobuf/FieldSet.java index bb3eea993d..a8ba1bd413 100644 --- a/java/core/src/main/java/com/google/protobuf/FieldSet.java +++ b/java/core/src/main/java/com/google/protobuf/FieldSet.java @@ -173,7 +173,8 @@ final class FieldSet> { /** Get a simple map containing all the fields. */ public Map getAllFields() { if (hasLazyField) { - SmallSortedMap result = cloneAllFieldsMap(fields, /* copyList */ false); + SmallSortedMap result = + cloneAllFieldsMap(fields, /* copyList= */ false, /* resolveLazyFields= */ true); if (fields.isImmutable()) { result.makeImmutable(); } @@ -183,22 +184,22 @@ final class FieldSet> { } private static > SmallSortedMap cloneAllFieldsMap( - SmallSortedMap fields, boolean copyList) { + SmallSortedMap fields, boolean copyList, boolean resolveLazyFields) { SmallSortedMap result = SmallSortedMap.newFieldMap(DEFAULT_FIELD_MAP_ARRAY_SIZE); for (int i = 0; i < fields.getNumArrayEntries(); i++) { - cloneFieldEntry(result, fields.getArrayEntryAt(i), copyList); + cloneFieldEntry(result, fields.getArrayEntryAt(i), copyList, resolveLazyFields); } for (Map.Entry entry : fields.getOverflowEntries()) { - cloneFieldEntry(result, entry, copyList); + cloneFieldEntry(result, entry, copyList, resolveLazyFields); } return result; } private static > void cloneFieldEntry( - Map map, Map.Entry entry, boolean copyList) { + Map map, Map.Entry entry, boolean copyList, boolean resolveLazyFields) { T key = entry.getKey(); Object value = entry.getValue(); - if (value instanceof LazyField) { + if (resolveLazyFields && value instanceof LazyField) { map.put(key, ((LazyField) value).getValue()); } else if (copyList && value instanceof List) { map.put(key, new ArrayList<>((List) value)); @@ -958,7 +959,8 @@ final class FieldSet> { SmallSortedMap fieldsForBuild = fields; if (hasNestedBuilders) { // Make a copy of the fields map with all Builders replaced by Message. - fieldsForBuild = cloneAllFieldsMap(fields, /* copyList */ false); + fieldsForBuild = + cloneAllFieldsMap(fields, /* copyList= */ false, /* resolveLazyFields= */ false); replaceBuilders(fieldsForBuild, partial); } FieldSet fieldSet = new FieldSet<>(fieldsForBuild); @@ -1030,7 +1032,10 @@ final class FieldSet> { /** Returns a new Builder using the fields from {@code fieldSet}. */ public static > Builder fromFieldSet(FieldSet fieldSet) { - Builder builder = new Builder(cloneAllFieldsMap(fieldSet.fields, /* copyList */ true)); + Builder builder = + new Builder( + cloneAllFieldsMap( + fieldSet.fields, /* copyList= */ true, /* resolveLazyFields= */ false)); builder.hasLazyField = fieldSet.hasLazyField; return builder; } @@ -1040,7 +1045,8 @@ final class FieldSet> { /** Get a simple map containing all the fields. */ public Map getAllFields() { if (hasLazyField) { - SmallSortedMap result = cloneAllFieldsMap(fields, /* copyList */ false); + SmallSortedMap result = + cloneAllFieldsMap(fields, /* copyList= */ false, /* resolveLazyFields= */ true); if (fields.isImmutable()) { result.makeImmutable(); } else { @@ -1081,7 +1087,7 @@ final class FieldSet> { private void ensureIsMutable() { if (!isMutable) { - fields = cloneAllFieldsMap(fields, /* copyList */ true); + fields = cloneAllFieldsMap(fields, /* copyList= */ true, /* resolveLazyFields= */ false); isMutable = true; } } diff --git a/java/core/src/test/java/com/google/protobuf/LazilyParsedMessageSetTest.java b/java/core/src/test/java/com/google/protobuf/LazilyParsedMessageSetTest.java new file mode 100644 index 0000000000..c41a381823 --- /dev/null +++ b/java/core/src/test/java/com/google/protobuf/LazilyParsedMessageSetTest.java @@ -0,0 +1,162 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +package com.google.protobuf; + +import static com.google.common.truth.Truth.assertThat; + +import protobuf_unittest.UnittestMset.RawMessageSet; +import protobuf_unittest.UnittestMset.TestMessageSetExtension1; +import protobuf_unittest.UnittestMset.TestMessageSetExtension2; +import protobuf_unittest.UnittestMset.TestMessageSetExtension3; +import proto2_wireformat_unittest.UnittestMsetWireFormat.TestMessageSet; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests related to handling of MessageSets with lazily parsed extensions. */ +@RunWith(JUnit4.class) +public class LazilyParsedMessageSetTest { + private static final int TYPE_ID_1 = + TestMessageSetExtension1.getDescriptor().getExtensions().get(0).getNumber(); + private static final int TYPE_ID_2 = + TestMessageSetExtension2.getDescriptor().getExtensions().get(0).getNumber(); + private static final int TYPE_ID_3 = + TestMessageSetExtension3.getDescriptor().getExtensions().get(0).getNumber(); + private static final ByteString CORRUPTED_MESSAGE_PAYLOAD = + ByteString.copyFrom(new byte[] {(byte) 0xff}); + + @Before + public void setUp() { + ExtensionRegistryLite.setEagerlyParseMessageSets(false); + } + + @Test + public void testParseAndUpdateMessageSet_unaccessedLazyFieldsAreNotLoaded() throws Exception { + ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); + extensionRegistry.add(TestMessageSetExtension1.messageSetExtension); + extensionRegistry.add(TestMessageSetExtension2.messageSetExtension); + extensionRegistry.add(TestMessageSetExtension3.messageSetExtension); + + // Set up a TestMessageSet with 2 extensions. The first extension has corrupted payload + // data. The test below makes sure that we never load this extension. If we ever do, then we + // will handle the exception and replace the value with the default empty message (this behavior + // is tested below in testLoadCorruptedLazyField_getsReplacedWithEmptyMessage). Later on we + // check that when we serialize the message set, we still have corrupted payload for the first + // extension. + RawMessageSet inputRaw = + RawMessageSet.newBuilder() + .addItem( + RawMessageSet.Item.newBuilder() + .setTypeId(TYPE_ID_1) + .setMessage(CORRUPTED_MESSAGE_PAYLOAD)) + .addItem( + RawMessageSet.Item.newBuilder() + .setTypeId(TYPE_ID_2) + .setMessage( + TestMessageSetExtension2.newBuilder().setStr("foo").build().toByteString())) + .build(); + + ByteString inputData = inputRaw.toByteString(); + + // Re-parse as a TestMessageSet, so that all extensions are lazy + TestMessageSet messageSet = TestMessageSet.parseFrom(inputData, extensionRegistry); + + // Update one extension and add a new one. + TestMessageSet.Builder builder = messageSet.toBuilder(); + builder.setExtension( + TestMessageSetExtension2.messageSetExtension, + TestMessageSetExtension2.newBuilder().setStr("bar").build()); + + // Call .build() in the middle of updating the builder. This triggers a codepath that we want to + // make sure preserves lazy fields. + TestMessageSet unusedIntermediateMessageSet = builder.build(); + + builder.setExtension( + TestMessageSetExtension3.messageSetExtension, + TestMessageSetExtension3.newBuilder().setRequiredInt(666).build()); + + TestMessageSet updatedMessageSet = builder.build(); + + // Check that hasExtension call does not load lazy fields. + assertThat(updatedMessageSet.hasExtension(TestMessageSetExtension1.messageSetExtension)) + .isTrue(); + + // Serialize. The first extension should still be unloaded and will get serialized using the + // same corrupted byte array. + ByteString outputData = updatedMessageSet.toByteString(); + + // Re-parse as RawMessageSet + RawMessageSet actualRaw = + RawMessageSet.parseFrom(outputData, ExtensionRegistry.getEmptyRegistry()); + + RawMessageSet expectedRaw = + RawMessageSet.newBuilder() + .addItem( + RawMessageSet.Item.newBuilder() + .setTypeId(TYPE_ID_1) + // This is the important part -- we want to make sure that the payload of the + // 1st extensions is the same corrupted byte array. If we ever load the + // extension during our manipulations above, then we would have replaced it with + // the default empty message. + .setMessage(CORRUPTED_MESSAGE_PAYLOAD)) + .addItem( + RawMessageSet.Item.newBuilder() + .setTypeId(TYPE_ID_2) + .setMessage( + TestMessageSetExtension2.newBuilder().setStr("bar").build().toByteString())) + .addItem( + RawMessageSet.Item.newBuilder() + .setTypeId(TYPE_ID_3) + .setMessage( + TestMessageSetExtension3.newBuilder() + .setRequiredInt(666) + .build() + .toByteString())) + .build(); + + assertThat(actualRaw).isEqualTo(expectedRaw); + } + + @Test + public void testLoadCorruptedLazyField_getsReplacedWithEmptyMessage() throws Exception { + ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); + extensionRegistry.add(TestMessageSetExtension1.messageSetExtension); + + RawMessageSet inputRaw = + RawMessageSet.newBuilder() + .addItem( + RawMessageSet.Item.newBuilder() + .setTypeId(TYPE_ID_1) + .setMessage(CORRUPTED_MESSAGE_PAYLOAD)) + .build(); + + ByteString inputData = inputRaw.toByteString(); + + // Re-parse as a TestMessageSet, so that all extensions are lazy + TestMessageSet messageSet = TestMessageSet.parseFrom(inputData, extensionRegistry); + + assertThat(messageSet.getExtension(TestMessageSetExtension1.messageSetExtension)) + .isEqualTo(TestMessageSetExtension1.getDefaultInstance()); + + // Serialize. The first extension should be serialized as an empty message. + ByteString outputData = messageSet.toByteString(); + + // Re-parse as RawMessageSet + RawMessageSet actualRaw = + RawMessageSet.parseFrom(outputData, ExtensionRegistry.getEmptyRegistry()); + + RawMessageSet expectedRaw = + RawMessageSet.newBuilder() + .addItem( + RawMessageSet.Item.newBuilder().setTypeId(TYPE_ID_1).setMessage(ByteString.empty())) + .build(); + + assertThat(actualRaw).isEqualTo(expectedRaw); + } +} diff --git a/java/core/src/test/java/com/google/protobuf/WireFormatTest.java b/java/core/src/test/java/com/google/protobuf/WireFormatTest.java index 4afeff8f17..bbf8d0cba6 100644 --- a/java/core/src/test/java/com/google/protobuf/WireFormatTest.java +++ b/java/core/src/test/java/com/google/protobuf/WireFormatTest.java @@ -12,7 +12,6 @@ import static com.google.common.truth.Truth.assertThat; import protobuf_unittest.UnittestMset.RawMessageSet; import protobuf_unittest.UnittestMset.TestMessageSetExtension1; import protobuf_unittest.UnittestMset.TestMessageSetExtension2; -import protobuf_unittest.UnittestMset.TestMessageSetExtension3; import protobuf_unittest.UnittestProto; import protobuf_unittest.UnittestProto.TestAllExtensions; import protobuf_unittest.UnittestProto.TestAllTypes; @@ -506,73 +505,6 @@ public class WireFormatTest { .isEqualTo(123); } - @Test - public void testParseAndUpdateMessageSetExtensionEagerly() throws Exception { - testParseAndUpdateMessageSetExtensionEagerlyWithFlag(true); - } - - @Test - public void testParseAndUpdateMessageSetExtensionNotEagerly() throws Exception { - testParseAndUpdateMessageSetExtensionEagerlyWithFlag(false); - } - - private void testParseAndUpdateMessageSetExtensionEagerlyWithFlag(boolean eagerParsing) - throws Exception { - ExtensionRegistryLite.setEagerlyParseMessageSets(eagerParsing); - ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); - extensionRegistry.add(TestMessageSetExtension1.messageSetExtension); - extensionRegistry.add(TestMessageSetExtension2.messageSetExtension); - extensionRegistry.add(TestMessageSetExtension3.messageSetExtension); - - // Set up a RawMessageSet with 2 extensions - RawMessageSet raw = - RawMessageSet.newBuilder() - .addItem( - RawMessageSet.Item.newBuilder() - .setTypeId(TYPE_ID_1) - .setMessage( - TestMessageSetExtension1.newBuilder().setI(123).build().toByteString()) - .build()) - .addItem( - RawMessageSet.Item.newBuilder() - .setTypeId(TYPE_ID_2) - .setMessage( - TestMessageSetExtension2.newBuilder().setStr("foo").build().toByteString()) - .build()) - .build(); - - ByteString data = raw.toByteString(); - - // Parse as a TestMessageSet. - TestMessageSet messageSet = TestMessageSet.parseFrom(data, extensionRegistry); - - // Update one extension and add a new one. - TestMessageSet.Builder builder = messageSet.toBuilder(); - builder.setExtension( - TestMessageSetExtension2.messageSetExtension, - TestMessageSetExtension2.newBuilder().setStr("bar").build()); - builder.setExtension( - TestMessageSetExtension3.messageSetExtension, - TestMessageSetExtension3.newBuilder().setRequiredInt(666).build()); - - TestMessageSet updatedMessageSet = builder.build(); - // Check all 3 extensions - assertThat(updatedMessageSet.getExtension(TestMessageSetExtension1.messageSetExtension).getI()) - .isEqualTo(123); - assertThat( - updatedMessageSet.getExtension(TestMessageSetExtension2.messageSetExtension).getStr()) - .isEqualTo("bar"); - assertThat( - updatedMessageSet - .getExtension(TestMessageSetExtension3.messageSetExtension) - .getRequiredInt()) - .isEqualTo(666); - - // Serialize and re-parse, and make sure we get the same message back - assertThat(TestMessageSet.parseFrom(updatedMessageSet.toByteString(), extensionRegistry)) - .isEqualTo(updatedMessageSet); - } - // ================================================================ // oneof @Test