Merge tag 'refs/tags/sync-piper' into sync-stage

pull/9665/head
Darly Paredes 3 years ago
commit 8e3746559d
  1. 44
      java/core/src/main/java/com/google/protobuf/TextFormat.java
  2. 3
      java/core/src/test/java/com/google/protobuf/CheckUtf8Test.java
  3. 2
      java/core/src/test/java/com/google/protobuf/GeneratedMessageTest.java
  4. 8
      java/core/src/test/java/com/google/protobuf/MapForProto2LiteTest.java
  5. 8
      java/core/src/test/java/com/google/protobuf/MapForProto2Test.java
  6. 8
      java/core/src/test/java/com/google/protobuf/MapLiteTest.java
  7. 8
      java/core/src/test/java/com/google/protobuf/MapTest.java
  8. 2
      java/core/src/test/java/com/google/protobuf/ParserLiteTest.java
  9. 5
      java/core/src/test/java/com/google/protobuf/ParserTest.java
  10. 17
      python/google/protobuf/internal/unknown_fields_test.py
  11. 2
      python/google/protobuf/internal/well_known_types.py
  12. 16
      python/google/protobuf/pyext/message.cc
  13. 353
      python/google/protobuf/pyext/unknown_field_set.cc
  14. 78
      python/google/protobuf/pyext/unknown_field_set.h
  15. 13
      python/google/protobuf/unknown_fields.py
  16. 5
      src/google/protobuf/compiler/java/java_enum.cc
  17. 7
      src/google/protobuf/compiler/java/java_enum_field.cc
  18. 5
      src/google/protobuf/compiler/java/java_enum_field_lite.cc
  19. 5
      src/google/protobuf/compiler/java/java_extension.cc
  20. 5
      src/google/protobuf/compiler/java/java_extension_lite.cc
  21. 5
      src/google/protobuf/compiler/java/java_file.cc
  22. 5
      src/google/protobuf/compiler/java/java_helpers.cc
  23. 5
      src/google/protobuf/compiler/java/java_map_field.cc
  24. 5
      src/google/protobuf/compiler/java/java_map_field_lite.cc
  25. 5
      src/google/protobuf/compiler/java/java_message.cc
  26. 5
      src/google/protobuf/compiler/java/java_message_builder.cc
  27. 5
      src/google/protobuf/compiler/java/java_message_builder_lite.cc
  28. 5
      src/google/protobuf/compiler/java/java_message_field.cc
  29. 5
      src/google/protobuf/compiler/java/java_message_field_lite.cc
  30. 5
      src/google/protobuf/compiler/java/java_message_lite.cc
  31. 5
      src/google/protobuf/compiler/java/java_name_resolver.cc
  32. 5
      src/google/protobuf/compiler/java/java_name_resolver.h
  33. 5
      src/google/protobuf/compiler/java/java_service.cc
  34. 2
      src/google/protobuf/compiler/python/python_helpers.cc
  35. 201
      src/google/protobuf/compiler/python/python_pyi_generator.cc
  36. 28
      src/google/protobuf/compiler/python/python_pyi_generator.h
  37. 125
      src/google/protobuf/generated_message_tctable_lite.cc
  38. 87
      src/google/protobuf/io/tokenizer.cc
  39. 112
      src/google/protobuf/io/tokenizer_unittest.cc
  40. 13
      src/google/protobuf/port_def.inc
  41. 2
      src/google/protobuf/port_undef.inc
  42. 2
      src/google/protobuf/repeated_ptr_field.h
  43. 2
      src/google/protobuf/text_format.cc

@ -989,12 +989,12 @@ public final class TextFormat {
} }
/** Are we at the end of the input? */ /** Are we at the end of the input? */
public boolean atEnd() { boolean atEnd() {
return currentToken.length() == 0; return currentToken.length() == 0;
} }
/** Advance to the next token. */ /** Advance to the next token. */
public void nextToken() { void nextToken() {
previousLine = line; previousLine = line;
previousColumn = column; previousColumn = column;
@ -1040,7 +1040,7 @@ public final class TextFormat {
* If the next token exactly matches {@code token}, consume it and return {@code true}. * If the next token exactly matches {@code token}, consume it and return {@code true}.
* Otherwise, return {@code false} without doing anything. * Otherwise, return {@code false} without doing anything.
*/ */
public boolean tryConsume(final String token) { boolean tryConsume(final String token) {
if (currentToken.equals(token)) { if (currentToken.equals(token)) {
nextToken(); nextToken();
return true; return true;
@ -1053,14 +1053,14 @@ public final class TextFormat {
* If the next token exactly matches {@code token}, consume it. Otherwise, throw a {@link * If the next token exactly matches {@code token}, consume it. Otherwise, throw a {@link
* ParseException}. * ParseException}.
*/ */
public void consume(final String token) throws ParseException { void consume(final String token) throws ParseException {
if (!tryConsume(token)) { if (!tryConsume(token)) {
throw parseException("Expected \"" + token + "\"."); throw parseException("Expected \"" + token + "\".");
} }
} }
/** Returns {@code true} if the next token is an integer, but does not consume it. */ /** Returns {@code true} if the next token is an integer, but does not consume it. */
public boolean lookingAtInteger() { boolean lookingAtInteger() {
if (currentToken.length() == 0) { if (currentToken.length() == 0) {
return false; return false;
} }
@ -1070,7 +1070,7 @@ public final class TextFormat {
} }
/** Returns {@code true} if the current token's text is equal to that specified. */ /** Returns {@code true} if the current token's text is equal to that specified. */
public boolean lookingAt(String text) { boolean lookingAt(String text) {
return currentToken.equals(text); return currentToken.equals(text);
} }
@ -1078,7 +1078,7 @@ public final class TextFormat {
* If the next token is an identifier, consume it and return its value. Otherwise, throw a * If the next token is an identifier, consume it and return its value. Otherwise, throw a
* {@link ParseException}. * {@link ParseException}.
*/ */
public String consumeIdentifier() throws ParseException { String consumeIdentifier() throws ParseException {
for (int i = 0; i < currentToken.length(); i++) { for (int i = 0; i < currentToken.length(); i++) {
final char c = currentToken.charAt(i); final char c = currentToken.charAt(i);
if (('a' <= c && c <= 'z') if (('a' <= c && c <= 'z')
@ -1101,7 +1101,7 @@ public final class TextFormat {
* If the next token is an identifier, consume it and return {@code true}. Otherwise, return * If the next token is an identifier, consume it and return {@code true}. Otherwise, return
* {@code false} without doing anything. * {@code false} without doing anything.
*/ */
public boolean tryConsumeIdentifier() { boolean tryConsumeIdentifier() {
try { try {
consumeIdentifier(); consumeIdentifier();
return true; return true;
@ -1114,7 +1114,7 @@ public final class TextFormat {
* If the next token is a 32-bit signed integer, consume it and return its value. Otherwise, * If the next token is a 32-bit signed integer, consume it and return its value. Otherwise,
* throw a {@link ParseException}. * throw a {@link ParseException}.
*/ */
public int consumeInt32() throws ParseException { int consumeInt32() throws ParseException {
try { try {
final int result = parseInt32(currentToken); final int result = parseInt32(currentToken);
nextToken(); nextToken();
@ -1128,7 +1128,7 @@ public final class TextFormat {
* If the next token is a 32-bit unsigned integer, consume it and return its value. Otherwise, * If the next token is a 32-bit unsigned integer, consume it and return its value. Otherwise,
* throw a {@link ParseException}. * throw a {@link ParseException}.
*/ */
public int consumeUInt32() throws ParseException { int consumeUInt32() throws ParseException {
try { try {
final int result = parseUInt32(currentToken); final int result = parseUInt32(currentToken);
nextToken(); nextToken();
@ -1142,7 +1142,7 @@ public final class TextFormat {
* If the next token is a 64-bit signed integer, consume it and return its value. Otherwise, * If the next token is a 64-bit signed integer, consume it and return its value. Otherwise,
* throw a {@link ParseException}. * throw a {@link ParseException}.
*/ */
public long consumeInt64() throws ParseException { long consumeInt64() throws ParseException {
try { try {
final long result = parseInt64(currentToken); final long result = parseInt64(currentToken);
nextToken(); nextToken();
@ -1156,7 +1156,7 @@ public final class TextFormat {
* If the next token is a 64-bit signed integer, consume it and return {@code true}. Otherwise, * If the next token is a 64-bit signed integer, consume it and return {@code true}. Otherwise,
* return {@code false} without doing anything. * return {@code false} without doing anything.
*/ */
public boolean tryConsumeInt64() { boolean tryConsumeInt64() {
try { try {
consumeInt64(); consumeInt64();
return true; return true;
@ -1169,7 +1169,7 @@ public final class TextFormat {
* If the next token is a 64-bit unsigned integer, consume it and return its value. Otherwise, * If the next token is a 64-bit unsigned integer, consume it and return its value. Otherwise,
* throw a {@link ParseException}. * throw a {@link ParseException}.
*/ */
public long consumeUInt64() throws ParseException { long consumeUInt64() throws ParseException {
try { try {
final long result = parseUInt64(currentToken); final long result = parseUInt64(currentToken);
nextToken(); nextToken();
@ -1299,7 +1299,7 @@ public final class TextFormat {
} }
/** If the next token is a string, consume it and return true. Otherwise, return false. */ /** If the next token is a string, consume it and return true. Otherwise, return false. */
public boolean tryConsumeString() { boolean tryConsumeString() {
try { try {
consumeString(); consumeString();
return true; return true;
@ -1312,7 +1312,7 @@ public final class TextFormat {
* If the next token is a string, consume it, unescape it as a {@link ByteString}, and return * If the next token is a string, consume it, unescape it as a {@link ByteString}, and return
* it. Otherwise, throw a {@link ParseException}. * it. Otherwise, throw a {@link ParseException}.
*/ */
public ByteString consumeByteString() throws ParseException { ByteString consumeByteString() throws ParseException {
List<ByteString> list = new ArrayList<ByteString>(); List<ByteString> list = new ArrayList<ByteString>();
consumeByteString(list); consumeByteString(list);
while (currentToken.startsWith("'") || currentToken.startsWith("\"")) { while (currentToken.startsWith("'") || currentToken.startsWith("\"")) {
@ -1350,7 +1350,7 @@ public final class TextFormat {
* Returns a {@link ParseException} with the current line and column numbers in the description, * Returns a {@link ParseException} with the current line and column numbers in the description,
* suitable for throwing. * suitable for throwing.
*/ */
public ParseException parseException(final String description) { ParseException parseException(final String description) {
// Note: People generally prefer one-based line and column numbers. // Note: People generally prefer one-based line and column numbers.
return new ParseException(line + 1, column + 1, description); return new ParseException(line + 1, column + 1, description);
} }
@ -1359,7 +1359,7 @@ public final class TextFormat {
* Returns a {@link ParseException} with the line and column numbers of the previous token in * Returns a {@link ParseException} with the line and column numbers of the previous token in
* the description, suitable for throwing. * the description, suitable for throwing.
*/ */
public ParseException parseExceptionPreviousToken(final String description) { ParseException parseExceptionPreviousToken(final String description) {
// Note: People generally prefer one-based line and column numbers. // Note: People generally prefer one-based line and column numbers.
return new ParseException(previousLine + 1, previousColumn + 1, description); return new ParseException(previousLine + 1, previousColumn + 1, description);
} }
@ -1380,16 +1380,6 @@ public final class TextFormat {
return parseException("Couldn't parse number: " + e.getMessage()); return parseException("Couldn't parse number: " + e.getMessage());
} }
/**
* Returns a {@link UnknownFieldParseException} with the line and column numbers of the previous
* token in the description, and the unknown field name, suitable for throwing.
*/
public UnknownFieldParseException unknownFieldParseExceptionPreviousToken(
final String unknownField, final String description) {
// Note: People generally prefer one-based line and column numbers.
return new UnknownFieldParseException(
previousLine + 1, previousColumn + 1, unknownField, description);
}
} }
/** Thrown when parsing an invalid text format message. */ /** Thrown when parsing an invalid text format message. */

@ -64,8 +64,7 @@ public class CheckUtf8Test {
public void testParseRequiredStringWithGoodUtf8() throws Exception { public void testParseRequiredStringWithGoodUtf8() throws Exception {
ByteString serialized = ByteString serialized =
BytesWrapper.newBuilder().setReq(UTF8_BYTE_STRING).build().toByteString(); BytesWrapper.newBuilder().setReq(UTF8_BYTE_STRING).build().toByteString();
assertThat(StringWrapper.parser().parseFrom(serialized).getReq()) assertThat(StringWrapper.parseFrom(serialized).getReq()).isEqualTo(UTF8_BYTE_STRING_TEXT);
.isEqualTo(UTF8_BYTE_STRING_TEXT);
} }
@Test @Test

@ -359,7 +359,7 @@ public class GeneratedMessageTest {
@Test @Test
public void testParsedMessagesAreImmutable() throws Exception { public void testParsedMessagesAreImmutable() throws Exception {
TestAllTypes value = TestAllTypes.parser().parseFrom(TestUtil.getAllSet().toByteString()); TestAllTypes value = TestAllTypes.parseFrom(TestUtil.getAllSet().toByteString());
assertIsUnmodifiable(value.getRepeatedInt32List()); assertIsUnmodifiable(value.getRepeatedInt32List());
assertIsUnmodifiable(value.getRepeatedInt64List()); assertIsUnmodifiable(value.getRepeatedInt64List());
assertIsUnmodifiable(value.getRepeatedUint32List()); assertIsUnmodifiable(value.getRepeatedUint32List());

@ -392,21 +392,21 @@ public final class MapForProto2LiteTest {
setMapValues(builder); setMapValues(builder);
TestMap message = builder.build(); TestMap message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesSet(message); assertMapValuesSet(message);
builder = message.toBuilder(); builder = message.toBuilder();
updateMapValues(builder); updateMapValues(builder);
message = builder.build(); message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesUpdated(message); assertMapValuesUpdated(message);
builder = message.toBuilder(); builder = message.toBuilder();
builder.clear(); builder.clear();
message = builder.build(); message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesCleared(message); assertMapValuesCleared(message);
} }
@ -415,7 +415,7 @@ public final class MapForProto2LiteTest {
CodedOutputStream output = CodedOutputStream.newInstance(byteArrayOutputStream); CodedOutputStream output = CodedOutputStream.newInstance(byteArrayOutputStream);
bizarroMap.writeTo(output); bizarroMap.writeTo(output);
output.flush(); output.flush();
return TestMap.parser().parseFrom(ByteString.copyFrom(byteArrayOutputStream.toByteArray())); return TestMap.parseFrom(ByteString.copyFrom(byteArrayOutputStream.toByteArray()));
} }
@Test @Test

@ -534,21 +534,21 @@ public class MapForProto2Test {
setMapValuesUsingAccessors(builder); setMapValuesUsingAccessors(builder);
TestMap message = builder.build(); TestMap message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesSet(message); assertMapValuesSet(message);
builder = message.toBuilder(); builder = message.toBuilder();
updateMapValuesUsingAccessors(builder); updateMapValuesUsingAccessors(builder);
message = builder.build(); message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesUpdated(message); assertMapValuesUpdated(message);
builder = message.toBuilder(); builder = message.toBuilder();
builder.clear(); builder.clear();
message = builder.build(); message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesCleared(message); assertMapValuesCleared(message);
} }
@ -557,7 +557,7 @@ public class MapForProto2Test {
CodedOutputStream output = CodedOutputStream.newInstance(byteArrayOutputStream); CodedOutputStream output = CodedOutputStream.newInstance(byteArrayOutputStream);
bizarroMap.writeTo(output); bizarroMap.writeTo(output);
output.flush(); output.flush();
return TestMap.parser().parseFrom(ByteString.copyFrom(byteArrayOutputStream.toByteArray())); return TestMap.parseFrom(ByteString.copyFrom(byteArrayOutputStream.toByteArray()));
} }
@Test @Test

@ -425,21 +425,21 @@ public final class MapLiteTest {
setMapValues(builder); setMapValues(builder);
TestMap message = builder.build(); TestMap message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesSet(message); assertMapValuesSet(message);
builder = message.toBuilder(); builder = message.toBuilder();
updateMapValues(builder); updateMapValues(builder);
message = builder.build(); message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesUpdated(message); assertMapValuesUpdated(message);
builder = message.toBuilder(); builder = message.toBuilder();
builder.clear(); builder.clear();
message = builder.build(); message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesCleared(message); assertMapValuesCleared(message);
} }
@ -448,7 +448,7 @@ public final class MapLiteTest {
CodedOutputStream output = CodedOutputStream.newInstance(byteArrayOutputStream); CodedOutputStream output = CodedOutputStream.newInstance(byteArrayOutputStream);
bizarroMap.writeTo(output); bizarroMap.writeTo(output);
output.flush(); output.flush();
return TestMap.parser().parseFrom(ByteString.copyFrom(byteArrayOutputStream.toByteArray())); return TestMap.parseFrom(ByteString.copyFrom(byteArrayOutputStream.toByteArray()));
} }
@Test @Test

@ -580,21 +580,21 @@ public class MapTest {
setMapValuesUsingAccessors(builder); setMapValuesUsingAccessors(builder);
TestMap message = builder.build(); TestMap message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesSet(message); assertMapValuesSet(message);
builder = message.toBuilder(); builder = message.toBuilder();
updateMapValuesUsingAccessors(builder); updateMapValuesUsingAccessors(builder);
message = builder.build(); message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesUpdated(message); assertMapValuesUpdated(message);
builder = message.toBuilder(); builder = message.toBuilder();
builder.clear(); builder.clear();
message = builder.build(); message = builder.build();
assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize()); assertThat(message.toByteString().size()).isEqualTo(message.getSerializedSize());
message = TestMap.parser().parseFrom(message.toByteString()); message = TestMap.parseFrom(message.toByteString());
assertMapValuesCleared(message); assertMapValuesCleared(message);
} }
@ -603,7 +603,7 @@ public class MapTest {
CodedOutputStream output = CodedOutputStream.newInstance(byteArrayOutputStream); CodedOutputStream output = CodedOutputStream.newInstance(byteArrayOutputStream);
bizarroMap.writeTo(output); bizarroMap.writeTo(output);
output.flush(); output.flush();
return TestMap.parser().parseFrom(ByteString.copyFrom(byteArrayOutputStream.toByteArray())); return TestMap.parseFrom(ByteString.copyFrom(byteArrayOutputStream.toByteArray()));
} }
@Test @Test

@ -192,7 +192,7 @@ public class ParserLiteTest {
// Parse TestParsingMergeLite. // Parse TestParsingMergeLite.
ExtensionRegistryLite registry = ExtensionRegistryLite.newInstance(); ExtensionRegistryLite registry = ExtensionRegistryLite.newInstance();
UnittestLite.registerAllExtensions(registry); UnittestLite.registerAllExtensions(registry);
TestParsingMergeLite parsingMerge = TestParsingMergeLite.parser().parseFrom(data, registry); TestParsingMergeLite parsingMerge = TestParsingMergeLite.parseFrom(data, registry);
// Required and optional fields should be merged. // Required and optional fields should be merged.
assertMessageMerged(parsingMerge.getRequiredAllTypes()); assertMessageMerged(parsingMerge.getRequiredAllTypes());

@ -195,8 +195,7 @@ public class ParserTest {
@Test @Test
public void testParseUnknownFields() throws Exception { public void testParseUnknownFields() throws Exception {
// All fields will be treated as unknown fields in emptyMessage. // All fields will be treated as unknown fields in emptyMessage.
TestEmptyMessage emptyMessage = TestEmptyMessage emptyMessage = TestEmptyMessage.parseFrom(TestUtil.getAllSet().toByteString());
TestEmptyMessage.parser().parseFrom(TestUtil.getAllSet().toByteString());
assertThat(emptyMessage.toByteString()).isEqualTo(TestUtil.getAllSet().toByteString()); assertThat(emptyMessage.toByteString()).isEqualTo(TestUtil.getAllSet().toByteString());
} }
@ -278,7 +277,7 @@ public class ParserTest {
// Parse TestParsingMerge. // Parse TestParsingMerge.
ExtensionRegistry registry = ExtensionRegistry.newInstance(); ExtensionRegistry registry = ExtensionRegistry.newInstance();
UnittestProto.registerAllExtensions(registry); UnittestProto.registerAllExtensions(registry);
TestParsingMerge parsingMerge = TestParsingMerge.parser().parseFrom(data, registry); TestParsingMerge parsingMerge = TestParsingMerge.parseFrom(data, registry);
// Required and optional fields should be merged. // Required and optional fields should be merged.
assertMessageMerged(parsingMerge.getRequiredAllTypes()); assertMessageMerged(parsingMerge.getRequiredAllTypes());

@ -289,10 +289,6 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
unknown_field_set = unknown_fields.UnknownFieldSet(destination) unknown_field_set = unknown_fields.UnknownFieldSet(destination)
self.assertEqual(0, len(unknown_field_set)) self.assertEqual(0, len(unknown_field_set))
destination.ParseFromString(message.SerializeToString()) destination.ParseFromString(message.SerializeToString())
# TODO(jieluo): add this back after implement new cpp unknown fields
# b/217277954
if api_implementation.Type() == 'cpp':
return
self.assertEqual(0, len(unknown_field_set)) self.assertEqual(0, len(unknown_field_set))
unknown_field_set = unknown_fields.UnknownFieldSet(destination) unknown_field_set = unknown_fields.UnknownFieldSet(destination)
self.assertEqual(2, len(unknown_field_set)) self.assertEqual(2, len(unknown_field_set))
@ -310,10 +306,6 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
self.empty_message.Clear() self.empty_message.Clear()
# All cleared, even unknown fields. # All cleared, even unknown fields.
self.assertEqual(self.empty_message.SerializeToString(), b'') self.assertEqual(self.empty_message.SerializeToString(), b'')
# TODO(jieluo): add this back after implement new cpp unknown fields
# b/217277954
if api_implementation.Type() == 'cpp':
return
self.assertEqual(len(unknown_field_set), 97) self.assertEqual(len(unknown_field_set), 97)
@unittest.skipIf((sys.version_info.major, sys.version_info.minor) < (3, 4), @unittest.skipIf((sys.version_info.major, sys.version_info.minor) < (3, 4),
@ -345,10 +337,6 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
self.assertEqual(1, len(sub_unknown_fields)) self.assertEqual(1, len(sub_unknown_fields))
self.assertEqual(sub_unknown_fields[0].data, 123) self.assertEqual(sub_unknown_fields[0].data, 123)
destination.Clear() destination.Clear()
# TODO(jieluo): add this back after implement new cpp unknown fields
# b/217277954
if api_implementation.Type() == 'cpp':
return
self.assertEqual(1, len(sub_unknown_fields)) self.assertEqual(1, len(sub_unknown_fields))
self.assertEqual(sub_unknown_fields[0].data, 123) self.assertEqual(sub_unknown_fields[0].data, 123)
message.Clear() message.Clear()
@ -372,10 +360,6 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
destination.ParseFromString(message.SerializeToString()) destination.ParseFromString(message.SerializeToString())
unknown_field = unknown_fields.UnknownFieldSet(destination)[0] unknown_field = unknown_fields.UnknownFieldSet(destination)[0]
destination.Clear() destination.Clear()
# TODO(jieluo): add this back after implement new cpp unknown fields
# b/217277954
if api_implementation.Type() == 'cpp':
return
self.assertEqual(unknown_field.data, 123) self.assertEqual(unknown_field.data, 123)
def testUnknownExtensions(self): def testUnknownExtensions(self):
@ -416,6 +400,7 @@ class UnknownEnumValuesTest(unittest.TestCase):
def CheckUnknownField(self, name, expected_value): def CheckUnknownField(self, name, expected_value):
field_descriptor = self.descriptor.fields_by_name[name] field_descriptor = self.descriptor.fields_by_name[name]
unknown_field_set = unknown_fields.UnknownFieldSet(self.missing_message) unknown_field_set = unknown_fields.UnknownFieldSet(self.missing_message)
self.assertIsInstance(unknown_field_set, unknown_fields.UnknownFieldSet)
count = 0 count = 0
for field in unknown_field_set: for field in unknown_field_set:
if field.field_number == field_descriptor.number: if field.field_number == field_descriptor.number:

@ -868,6 +868,7 @@ class ListValue(object):
collections.abc.MutableSequence.register(ListValue) collections.abc.MutableSequence.register(ListValue)
# LINT.IfChange(wktbases)
WKTBASES = { WKTBASES = {
'google.protobuf.Any': Any, 'google.protobuf.Any': Any,
'google.protobuf.Duration': Duration, 'google.protobuf.Duration': Duration,
@ -876,3 +877,4 @@ WKTBASES = {
'google.protobuf.Struct': Struct, 'google.protobuf.Struct': Struct,
'google.protobuf.Timestamp': Timestamp, 'google.protobuf.Timestamp': Timestamp,
} }
# LINT.ThenChange(//depot/google.protobuf/compiler/python/pyi_generator.cc:wktbases)

@ -68,6 +68,7 @@
#include <google/protobuf/pyext/repeated_scalar_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/safe_numerics.h> #include <google/protobuf/pyext/safe_numerics.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/pyext/unknown_field_set.h>
#include <google/protobuf/pyext/unknown_fields.h> #include <google/protobuf/pyext/unknown_fields.h>
#include <google/protobuf/util/message_differencer.h> #include <google/protobuf/util/message_differencer.h>
#include <google/protobuf/io/strtod.h> #include <google/protobuf/io/strtod.h>
@ -2424,7 +2425,7 @@ static PyObject* GetExtensionDict(CMessage* self, void *closure) {
return reinterpret_cast<PyObject*>(extension_dict); return reinterpret_cast<PyObject*>(extension_dict);
} }
static PyObject* UnknownFieldSet(CMessage* self) { static PyObject* GetUnknownFields(CMessage* self) {
if (self->unknown_field_set == nullptr) { if (self->unknown_field_set == nullptr) {
self->unknown_field_set = unknown_fields::NewPyUnknownFields(self); self->unknown_field_set = unknown_fields::NewPyUnknownFields(self);
} else { } else {
@ -2493,7 +2494,7 @@ static PyMethodDef Methods[] = {
"Serializes the message to a string, only for initialized messages."}, "Serializes the message to a string, only for initialized messages."},
{"SetInParent", (PyCFunction)SetInParent, METH_NOARGS, {"SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
"Sets the has bit of the given field in its parent message."}, "Sets the has bit of the given field in its parent message."},
{"UnknownFields", (PyCFunction)UnknownFieldSet, METH_NOARGS, {"UnknownFields", (PyCFunction)GetUnknownFields, METH_NOARGS,
"Parse unknown field set"}, "Parse unknown field set"},
{"WhichOneof", (PyCFunction)WhichOneof, METH_O, {"WhichOneof", (PyCFunction)WhichOneof, METH_O,
"Returns the name of the field set inside a oneof, " "Returns the name of the field set inside a oneof, "
@ -2970,15 +2971,20 @@ bool InitProto2MessageModule(PyObject *m) {
return false; return false;
} }
if (PyType_Ready(&PyUnknownFieldSet_Type) < 0) {
return false;
}
PyModule_AddObject(m, "UnknownFieldSet", PyModule_AddObject(m, "UnknownFieldSet",
reinterpret_cast<PyObject*>(&PyUnknownFields_Type)); reinterpret_cast<PyObject*>(&PyUnknownFieldSet_Type));
if (PyType_Ready(&PyUnknownFieldRef_Type) < 0) { if (PyType_Ready(&PyUnknownFieldRef_Type) < 0) {
return false; return false;
} }
PyModule_AddObject(m, "UnknownField", if (PyType_Ready(&PyUnknownField_Type) < 0) {
reinterpret_cast<PyObject*>(&PyUnknownFieldRef_Type)); return false;
}
// Initialize Map container types. // Initialize Map container types.
if (!InitMapContainers()) { if (!InitMapContainers()) {

@ -0,0 +1,353 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <google/protobuf/pyext/unknown_field_set.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <memory>
#include <set>
#include <google/protobuf/message.h>
#include <google/protobuf/unknown_field_set.h>
#include <google/protobuf/wire_format_lite.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
namespace google {
namespace protobuf {
namespace python {
namespace unknown_field_set {
static Py_ssize_t Len(PyObject* pself) {
PyUnknownFieldSet* self = reinterpret_cast<PyUnknownFieldSet*>(pself);
if (self->fields == nullptr) {
PyErr_Format(PyExc_ValueError, "UnknownFieldSet does not exist. ");
return -1;
}
return self->fields->field_count();
}
PyObject* NewPyUnknownField(PyUnknownFieldSet* parent, Py_ssize_t index);
static PyObject* Item(PyObject* pself, Py_ssize_t index) {
PyUnknownFieldSet* self = reinterpret_cast<PyUnknownFieldSet*>(pself);
if (self->fields == nullptr) {
PyErr_Format(PyExc_ValueError, "UnknownFieldSet does not exist. ");
return nullptr;
}
Py_ssize_t total_size = self->fields->field_count();
if (index < 0) {
index = total_size + index;
}
if (index < 0 || index >= total_size) {
PyErr_Format(PyExc_IndexError, "index (%zd) out of range", index);
return nullptr;
}
return unknown_field_set::NewPyUnknownField(self, index);
}
PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
if (args == nullptr || PyTuple_Size(args) != 1) {
PyErr_SetString(PyExc_TypeError,
"Must provide a message to create UnknownFieldSet");
return nullptr;
}
PyObject* c_message;
if (!PyArg_ParseTuple(args, "O", &c_message)) {
PyErr_SetString(PyExc_TypeError,
"Must provide a message to create UnknownFieldSet");
return nullptr;
}
if (!PyObject_TypeCheck(c_message, CMessage_Type)) {
PyErr_Format(PyExc_TypeError,
"Parameter to UnknownFieldSet() must be a message "
"got %s.",
Py_TYPE(c_message)->tp_name);
return nullptr;
}
PyUnknownFieldSet* self = reinterpret_cast<PyUnknownFieldSet*>(
PyType_GenericAlloc(&PyUnknownFieldSet_Type, 0));
if (self == nullptr) {
return nullptr;
}
// Top UnknownFieldSet should set parent nullptr.
self->parent = nullptr;
// Copy c_message's UnknownFieldSet.
Message* message = reinterpret_cast<CMessage*>(c_message)->message;
const Reflection* reflection = message->GetReflection();
self->fields = new google::protobuf::UnknownFieldSet;
self->fields->MergeFrom(reflection->GetUnknownFields(*message));
return reinterpret_cast<PyObject*>(self);
}
PyObject* NewPyUnknownField(PyUnknownFieldSet* parent, Py_ssize_t index) {
PyUnknownField* self = reinterpret_cast<PyUnknownField*>(
PyType_GenericAlloc(&PyUnknownField_Type, 0));
if (self == nullptr) {
return nullptr;
}
Py_INCREF(parent);
self->parent = parent;
self->index = index;
return reinterpret_cast<PyObject*>(self);
}
static void Dealloc(PyObject* pself) {
PyUnknownFieldSet* self = reinterpret_cast<PyUnknownFieldSet*>(pself);
if (self->parent == nullptr) {
delete self->fields;
}
auto* py_type = Py_TYPE(pself);
self->~PyUnknownFieldSet();
py_type->tp_free(pself);
}
static PySequenceMethods SqMethods = {
Len, /* sq_length */
nullptr, /* sq_concat */
nullptr, /* sq_repeat */
Item, /* sq_item */
nullptr, /* sq_slice */
nullptr, /* sq_ass_item */
};
} // namespace unknown_field_set
PyTypeObject PyUnknownFieldSet_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
".PyUnknownFieldSet", // tp_name
sizeof(PyUnknownFieldSet), // tp_basicsize
0, // tp_itemsize
unknown_field_set::Dealloc, // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_compare
nullptr, // tp_repr
nullptr, // tp_as_number
&unknown_field_set::SqMethods, // tp_as_sequence
nullptr, // tp_as_mapping
PyObject_HashNotImplemented, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
"unknown field set", // tp_doc
nullptr, // tp_traverse
nullptr, // tp_clear
nullptr, // tp_richcompare
0, // tp_weaklistoffset
nullptr, // tp_iter
nullptr, // tp_iternext
nullptr, // tp_methods
nullptr, // tp_members
nullptr, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
nullptr, // tp_alloc
unknown_field_set::New, // tp_new
};
namespace unknown_field {
static PyObject* PyUnknownFieldSet_FromUnknownFieldSet(
PyUnknownFieldSet* parent, const UnknownFieldSet& fields) {
PyUnknownFieldSet* self = reinterpret_cast<PyUnknownFieldSet*>(
PyType_GenericAlloc(&PyUnknownFieldSet_Type, 0));
if (self == nullptr) {
return nullptr;
}
Py_INCREF(parent);
self->parent = parent;
self->fields = const_cast<UnknownFieldSet*>(&fields);
return reinterpret_cast<PyObject*>(self);
}
const UnknownField* GetUnknownField(PyUnknownField* self) {
const UnknownFieldSet* fields = self->parent->fields;
if (fields == nullptr) {
PyErr_Format(PyExc_ValueError, "UnknownField does not exist. ");
return nullptr;
}
Py_ssize_t total_size = fields->field_count();
if (self->index >= total_size) {
PyErr_Format(PyExc_ValueError, "UnknownField does not exist. ");
return nullptr;
}
return &fields->field(self->index);
}
static PyObject* GetFieldNumber(PyUnknownField* self, void* closure) {
const UnknownField* unknown_field = GetUnknownField(self);
if (unknown_field == nullptr) {
return nullptr;
}
return PyLong_FromLong(unknown_field->number());
}
using internal::WireFormatLite;
static PyObject* GetWireType(PyUnknownField* self, void* closure) {
const UnknownField* unknown_field = GetUnknownField(self);
if (unknown_field == nullptr) {
return nullptr;
}
// Assign a default value to suppress may-uninitialized warnings (errors
// when built in some places).
WireFormatLite::WireType wire_type = WireFormatLite::WIRETYPE_VARINT;
switch (unknown_field->type()) {
case UnknownField::TYPE_VARINT:
wire_type = WireFormatLite::WIRETYPE_VARINT;
break;
case UnknownField::TYPE_FIXED32:
wire_type = WireFormatLite::WIRETYPE_FIXED32;
break;
case UnknownField::TYPE_FIXED64:
wire_type = WireFormatLite::WIRETYPE_FIXED64;
break;
case UnknownField::TYPE_LENGTH_DELIMITED:
wire_type = WireFormatLite::WIRETYPE_LENGTH_DELIMITED;
break;
case UnknownField::TYPE_GROUP:
wire_type = WireFormatLite::WIRETYPE_START_GROUP;
break;
}
return PyLong_FromLong(wire_type);
}
static PyObject* GetData(PyUnknownField* self, void* closure) {
const UnknownField* field = GetUnknownField(self);
if (field == nullptr) {
return nullptr;
}
PyObject* data = nullptr;
switch (field->type()) {
case UnknownField::TYPE_VARINT:
data = PyLong_FromUnsignedLongLong(field->varint());
break;
case UnknownField::TYPE_FIXED32:
data = PyLong_FromUnsignedLong(field->fixed32());
break;
case UnknownField::TYPE_FIXED64:
data = PyLong_FromUnsignedLongLong(field->fixed64());
break;
case UnknownField::TYPE_LENGTH_DELIMITED:
data = PyBytes_FromStringAndSize(field->length_delimited().data(),
field->GetLengthDelimitedSize());
break;
case UnknownField::TYPE_GROUP:
data =
PyUnknownFieldSet_FromUnknownFieldSet(self->parent, field->group());
break;
}
return data;
}
static void Dealloc(PyObject* pself) {
PyUnknownField* self = reinterpret_cast<PyUnknownField*>(pself);
Py_CLEAR(self->parent);
}
static PyGetSetDef Getters[] = {
{"field_number", (getter)GetFieldNumber, nullptr},
{"wire_type", (getter)GetWireType, nullptr},
{"data", (getter)GetData, nullptr},
{nullptr},
};
} // namespace unknown_field
PyTypeObject PyUnknownField_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
".PyUnknownField", // tp_name
sizeof(PyUnknownField), // tp_basicsize
0, // tp_itemsize
unknown_field::Dealloc, // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_compare
nullptr, // tp_repr
nullptr, // tp_as_number
nullptr, // tp_as_sequence
nullptr, // tp_as_mapping
PyObject_HashNotImplemented, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
"unknown field", // tp_doc
nullptr, // tp_traverse
nullptr, // tp_clear
nullptr, // tp_richcompare
0, // tp_weaklistoffset
nullptr, // tp_iter
nullptr, // tp_iternext
nullptr, // tp_methods
nullptr, // tp_members
unknown_field::Getters, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
};
} // namespace python
} // namespace protobuf
} // namespace google

@ -0,0 +1,78 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_UNKNOWN_FIELD_SET_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_UNKNOWN_FIELD_SET_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <memory>
#include <set>
#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {
class UnknownField;
class UnknownFieldSet;
namespace python {
struct CMessage;
struct PyUnknownFieldSet {
PyObject_HEAD;
// If parent is nullptr, it is a top UnknownFieldSet.
PyUnknownFieldSet* parent;
// Top UnknownFieldSet owns fields pointer. Sub UnknownFieldSet
// does not own fields pointer.
UnknownFieldSet* fields;
};
struct PyUnknownField {
PyObject_HEAD;
// Every Python PyUnknownField holds a reference to its parent
// PyUnknownFieldSet in order to keep it alive.
PyUnknownFieldSet* parent;
// The UnknownField index in UnknownFieldSet.
Py_ssize_t index;
};
extern PyTypeObject PyUnknownFieldSet_Type;
extern PyTypeObject PyUnknownField_Type;
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_UNKNOWN_FIELD_SET_H__

@ -40,16 +40,15 @@ Simple usage example:
from google.protobuf.internal import api_implementation from google.protobuf.internal import api_implementation
from google.protobuf.internal import decoder if api_implementation.Type() == 'cpp':
from google.protobuf.internal import wire_format from google.protobuf.pyext import _message # pylint: disable=g-import-not-at-top
else:
from google.protobuf.internal import decoder # pylint: disable=g-import-not-at-top
from google.protobuf.internal import wire_format # pylint: disable=g-import-not-at-top
if api_implementation.Type() == 'cpp': if api_implementation.Type() == 'cpp':
def UnknownFieldSet(msg): UnknownFieldSet = _message.UnknownFieldSet
# New UnknownFieldSet in cpp extension has not implemented yet. Fall
# back to old API
# TODO(jieluo): Add UnknownFieldSet for cpp extension.
return msg.UnknownFields()
else: else:
class UnknownField: class UnknownField:

@ -45,6 +45,9 @@
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -390,3 +393,5 @@ bool EnumGenerator::CanUseEnumValues() {
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -48,6 +48,9 @@
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -859,7 +862,7 @@ void RepeatedImmutableEnumFieldGenerator::GenerateBuilderMembers(
"}\n"); "}\n");
printer->Annotate("{", "}", descriptor_); printer->Annotate("{", "}", descriptor_);
WriteFieldEnumValueAccessorDocComment(printer, descriptor_, WriteFieldEnumValueAccessorDocComment(printer, descriptor_,
LIST_INDEXED_GETTER, LIST_INDEXED_SETTER,
/* builder */ true); /* builder */ true);
printer->Print( printer->Print(
variables_, variables_,
@ -1174,3 +1177,5 @@ std::string RepeatedImmutableEnumFieldGenerator::GetBoxedType() const {
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -48,6 +48,9 @@
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -916,3 +919,5 @@ std::string RepeatedImmutableEnumFieldLiteGenerator::GetBoxedType() const {
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -41,6 +41,9 @@
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -170,3 +173,5 @@ int ImmutableExtensionGenerator::GenerateRegistrationCode(
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -37,6 +37,9 @@
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -113,3 +116,5 @@ int ImmutableExtensionLiteGenerator::GenerateRegistrationCode(
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -54,6 +54,9 @@
#include <google/protobuf/compiler/java/java_shared_code_generator.h> #include <google/protobuf/compiler/java/java_shared_code_generator.h>
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -732,3 +735,5 @@ bool FileGenerator::ShouldIncludeDependency(const FileDescriptor* descriptor,
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -49,6 +49,9 @@
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/stubs/hash.h> // for hash<T *> #include <google/protobuf/stubs/hash.h> // for hash<T *>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -1109,3 +1112,5 @@ void EscapeUtf16ToString(uint16_t code, std::string* output) {
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -36,6 +36,9 @@
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -887,3 +890,5 @@ std::string ImmutableMapFieldGenerator::GetBoxedType() const {
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -38,6 +38,9 @@
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -920,3 +923,5 @@ std::string ImmutableMapFieldLiteGenerator::GetBoxedType() const {
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -56,6 +56,9 @@
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -1749,3 +1752,5 @@ void ImmutableMessageGenerator::GenerateAnyMethods(io::Printer* printer) {
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -53,6 +53,9 @@
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -710,3 +713,5 @@ void MessageBuilderGenerator::GenerateIsInitialized(io::Printer* printer) {
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -53,6 +53,9 @@
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -149,3 +152,5 @@ void MessageBuilderLiteGenerator::GenerateCommonBuilderMethods(
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -45,6 +45,9 @@
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -1504,3 +1507,5 @@ void RepeatedImmutableMessageFieldGenerator::GenerateKotlinDslMembers(
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -46,6 +46,9 @@
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -891,3 +894,5 @@ void RepeatedImmutableMessageFieldLiteGenerator::GenerateKotlinDslMembers(
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -56,6 +56,9 @@
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -996,3 +999,5 @@ void ImmutableMessageLiteGenerator::GenerateKotlinExtensions(
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -38,6 +38,9 @@
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
#include <google/protobuf/compiler/java/java_names.h> #include <google/protobuf/compiler/java/java_names.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -378,3 +381,5 @@ std::string ClassNameResolver::GetDowngradedClassName(
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -36,6 +36,9 @@
#include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/common.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
class Descriptor; class Descriptor;
@ -151,4 +154,6 @@ class ClassNameResolver {
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>
#endif // GOOGLE_PROTOBUF_COMPILER_JAVA_NAME_RESOLVER_H__ #endif // GOOGLE_PROTOBUF_COMPILER_JAVA_NAME_RESOLVER_H__

@ -41,6 +41,9 @@
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
#include <google/protobuf/compiler/java/java_name_resolver.h> #include <google/protobuf/compiler/java/java_name_resolver.h>
// Must be last.
#include <google/protobuf/port_def.inc>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -472,3 +475,5 @@ void ImmutableServiceGenerator::GenerateBlockingMethodSignature(
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc>

@ -62,7 +62,7 @@ const char* const kKeywords[] = {
"del", "elif", "else", "except", "finally", "for", "del", "elif", "else", "except", "finally", "for",
"from", "global", "if", "import", "in", "is", "from", "global", "if", "import", "in", "is",
"lambda", "nonlocal", "not", "or", "pass", "raise", "lambda", "nonlocal", "not", "or", "pass", "raise",
"return", "try", "while", "with", "yield", "print", "return", "try", "while", "with", "yield",
}; };
const char* const* kKeywordsEnd = const char* const* kKeywordsEnd =
kKeywords + (sizeof(kKeywords) / sizeof(kKeywords[0])); kKeywords + (sizeof(kKeywords) / sizeof(kKeywords[0]));

@ -64,12 +64,21 @@ void PyiGenerator::PrintItemMap(
} }
template <typename DescriptorT> template <typename DescriptorT>
std::string PyiGenerator::ModuleLevelName(const DescriptorT& descriptor) const { std::string PyiGenerator::ModuleLevelName(
const DescriptorT& descriptor,
const std::map<std::string, std::string>& import_map) const {
std::string name = NamePrefixedWithNestedTypes(descriptor, "."); std::string name = NamePrefixedWithNestedTypes(descriptor, ".");
if (descriptor.file() != file_) { if (descriptor.file() != file_) {
std::string module_name = ModuleName(descriptor.file()->name()); std::string module_alias;
std::vector<std::string> tokens = Split(module_name, "."); std::string filename = descriptor.file()->name();
name = "_" + tokens.back() + "." + name; if (import_map.find(filename) == import_map.end()) {
std::string module_name = ModuleName(descriptor.file()->name());
std::vector<std::string> tokens = Split(module_name, ".");
module_alias = "_" + tokens.back();
} else {
module_alias = import_map.at(filename);
}
name = module_alias + "." + name;
} }
return name; return name;
} }
@ -82,9 +91,22 @@ struct ImportModules {
bool has_extendable = false; // _python_message bool has_extendable = false; // _python_message
bool has_mapping = false; // typing.Mapping bool has_mapping = false; // typing.Mapping
bool has_optional = false; // typing.Optional bool has_optional = false; // typing.Optional
bool has_union = false; // typing.Uion bool has_union = false; // typing.Union
bool has_well_known_type = false;
}; };
// Checks whether a descriptor name matches a well-known type.
bool IsWellKnownType(const std::string& name) {
// LINT.IfChange(wktbases)
return (name == "google.protobuf.Any" ||
name == "google.protobuf.Duration" ||
name == "google.protobuf.FieldMask" ||
name == "google.protobuf.ListValue" ||
name == "google.protobuf.Struct" ||
name == "google.protobuf.Timestamp");
// LINT.ThenChange(//depot/google3/net/proto2/python/internal/well_known_types.py:wktbases)
}
// Checks what modules should be imported for this message // Checks what modules should be imported for this message
// descriptor. // descriptor.
void CheckImportModules(const Descriptor* descriptor, void CheckImportModules(const Descriptor* descriptor,
@ -95,6 +117,9 @@ void CheckImportModules(const Descriptor* descriptor,
if (descriptor->enum_type_count() > 0) { if (descriptor->enum_type_count() > 0) {
import_modules->has_enums = true; import_modules->has_enums = true;
} }
if (IsWellKnownType(descriptor->full_name())) {
import_modules->has_well_known_type = true;
}
for (int i = 0; i < descriptor->field_count(); ++i) { for (int i = 0; i < descriptor->field_count(); ++i) {
const FieldDescriptor* field = descriptor->field(i); const FieldDescriptor* field = descriptor->field(i);
if (IsPythonKeyword(field->name())) { if (IsPythonKeyword(field->name())) {
@ -129,23 +154,44 @@ void CheckImportModules(const Descriptor* descriptor,
} }
} }
void PyiGenerator::PrintImportForDescriptor(
const FileDescriptor& desc,
std::map<std::string, std::string>* import_map,
std::set<std::string>* seen_aliases) const {
const std::string& filename = desc.name();
std::string module_name = StrippedModuleName(filename);
size_t last_dot_pos = module_name.rfind('.');
std::string import_statement;
if (last_dot_pos == std::string::npos) {
import_statement = "import " + module_name;
} else {
import_statement = "from " + module_name.substr(0, last_dot_pos) +
" import " + module_name.substr(last_dot_pos + 1);
module_name = module_name.substr(last_dot_pos + 1);
}
std::string alias = "_" + module_name;
// Generate a unique alias by adding _1 suffixes until we get an unused alias.
while (seen_aliases->find(alias) != seen_aliases->end()) {
alias = alias + "_1";
}
printer_->Print("$statement$ as $alias$\n", "statement",
import_statement, "alias", alias);
(*import_map)[filename] = alias;
seen_aliases->insert(alias);
}
void PyiGenerator::PrintImports( void PyiGenerator::PrintImports(
std::map<std::string, std::string>* item_map) const { std::map<std::string, std::string>* item_map,
std::map<std::string, std::string>* import_map) const {
// Prints imported dependent _pb2 files. // Prints imported dependent _pb2 files.
std::set<std::string> seen_aliases;
for (int i = 0; i < file_->dependency_count(); ++i) { for (int i = 0; i < file_->dependency_count(); ++i) {
const std::string& filename = file_->dependency(i)->name(); const FileDescriptor* dep = file_->dependency(i);
std::string module_name = StrippedModuleName(filename); PrintImportForDescriptor(*dep, import_map, &seen_aliases);
size_t last_dot_pos = module_name.rfind('.'); for (int j = 0; j < dep->public_dependency_count(); ++j) {
std::string import_statement; PrintImportForDescriptor(
if (last_dot_pos == std::string::npos) { *dep->public_dependency(j), import_map, &seen_aliases);
import_statement = "import " + module_name;
} else {
import_statement = "from " + module_name.substr(0, last_dot_pos) +
" import " + module_name.substr(last_dot_pos + 1);
module_name = module_name.substr(last_dot_pos + 1);
} }
printer_->Print("$statement$ as _$module_name$\n", "statement",
import_statement, "module_name", module_name);
} }
// Checks what modules should be imported. // Checks what modules should be imported.
@ -177,6 +223,11 @@ void PyiGenerator::PrintImports(
"from google.protobuf.internal import python_message" "from google.protobuf.internal import python_message"
" as _python_message\n"); " as _python_message\n");
} }
if (import_modules.has_well_known_type) {
printer_->Print(
"from google.protobuf.internal import well_known_types"
" as _well_known_types\n");
}
printer_->Print( printer_->Print(
"from google.protobuf import" "from google.protobuf import"
" descriptor as _descriptor\n"); " descriptor as _descriptor\n");
@ -190,21 +241,18 @@ void PyiGenerator::PrintImports(
" _service\n"); " _service\n");
} }
printer_->Print("from typing import "); printer_->Print("from typing import ");
printer_->Print("ClassVar"); printer_->Print("ClassVar as _ClassVar");
if (import_modules.has_iterable) { if (import_modules.has_iterable) {
printer_->Print(", Iterable"); printer_->Print(", Iterable as _Iterable");
} }
if (import_modules.has_mapping) { if (import_modules.has_mapping) {
printer_->Print(", Mapping"); printer_->Print(", Mapping as _Mapping");
} }
if (import_modules.has_optional) { if (import_modules.has_optional) {
printer_->Print(", Optional"); printer_->Print(", Optional as _Optional");
}
if (file_->service_count() > 0) {
printer_->Print(", Text");
} }
if (import_modules.has_union) { if (import_modules.has_union) {
printer_->Print(", Union"); printer_->Print(", Union as _Union");
} }
printer_->Print("\n\n"); printer_->Print("\n\n");
@ -229,7 +277,7 @@ void PyiGenerator::PrintImports(
const EnumDescriptor* enum_descriptor = public_dep->enum_type(i); const EnumDescriptor* enum_descriptor = public_dep->enum_type(i);
for (int j = 0; j < enum_descriptor->value_count(); ++j) { for (int j = 0; j < enum_descriptor->value_count(); ++j) {
(*item_map)[enum_descriptor->value(j)->name()] = (*item_map)[enum_descriptor->value(j)->name()] =
ModuleLevelName(*enum_descriptor); ModuleLevelName(*enum_descriptor, *import_map);
} }
} }
// Top level extensions for public imports // Top level extensions for public imports
@ -248,9 +296,10 @@ void PyiGenerator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
// Adds enum value to item map which will be ordered and printed later. // Adds enum value to item map which will be ordered and printed later.
void PyiGenerator::AddEnumValue( void PyiGenerator::AddEnumValue(
const EnumDescriptor& enum_descriptor, const EnumDescriptor& enum_descriptor,
std::map<std::string, std::string>* item_map) const { std::map<std::string, std::string>* item_map,
const std::map<std::string, std::string>& import_map) const {
// enum values // enum values
std::string module_enum_name = ModuleLevelName(enum_descriptor); std::string module_enum_name = ModuleLevelName(enum_descriptor, import_map);
for (int j = 0; j < enum_descriptor.value_count(); ++j) { for (int j = 0; j < enum_descriptor.value_count(); ++j) {
const EnumValueDescriptor* value_descriptor = enum_descriptor.value(j); const EnumValueDescriptor* value_descriptor = enum_descriptor.value(j);
(*item_map)[value_descriptor->name()] = module_enum_name; (*item_map)[value_descriptor->name()] = module_enum_name;
@ -275,13 +324,15 @@ void PyiGenerator::AddExtensions(
const FieldDescriptor* extension_field = descriptor.extension(i); const FieldDescriptor* extension_field = descriptor.extension(i);
std::string constant_name = extension_field->name() + "_FIELD_NUMBER"; std::string constant_name = extension_field->name() + "_FIELD_NUMBER";
ToUpper(&constant_name); ToUpper(&constant_name);
(*item_map)[constant_name] = "ClassVar[int]"; (*item_map)[constant_name] = "_ClassVar[int]";
(*item_map)[extension_field->name()] = "_descriptor.FieldDescriptor"; (*item_map)[extension_field->name()] = "_descriptor.FieldDescriptor";
} }
} }
// Returns the string format of a field's cpp_type // Returns the string format of a field's cpp_type
std::string PyiGenerator::GetFieldType(const FieldDescriptor& field_des) const { std::string PyiGenerator::GetFieldType(
const FieldDescriptor& field_des, const Descriptor& containing_des,
const std::map<std::string, std::string>& import_map) const {
switch (field_des.cpp_type()) { switch (field_des.cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32: case FieldDescriptor::CPPTYPE_INT32:
case FieldDescriptor::CPPTYPE_UINT32: case FieldDescriptor::CPPTYPE_UINT32:
@ -294,29 +345,48 @@ std::string PyiGenerator::GetFieldType(const FieldDescriptor& field_des) const {
case FieldDescriptor::CPPTYPE_BOOL: case FieldDescriptor::CPPTYPE_BOOL:
return "bool"; return "bool";
case FieldDescriptor::CPPTYPE_ENUM: case FieldDescriptor::CPPTYPE_ENUM:
return ModuleLevelName(*field_des.enum_type()); return ModuleLevelName(*field_des.enum_type(), import_map);
case FieldDescriptor::CPPTYPE_STRING: case FieldDescriptor::CPPTYPE_STRING:
if (field_des.type() == FieldDescriptor::TYPE_STRING) { if (field_des.type() == FieldDescriptor::TYPE_STRING) {
return "str"; return "str";
} else { } else {
return "bytes"; return "bytes";
} }
case FieldDescriptor::CPPTYPE_MESSAGE: case FieldDescriptor::CPPTYPE_MESSAGE: {
return ModuleLevelName(*field_des.message_type()); // If the field is inside a nested message and the nested message has the
// same name as a top-level message, then we need to prefix the field type
// with the module name for disambiguation.
std::string name = ModuleLevelName(*field_des.message_type(), import_map);
if ((containing_des.containing_type() != nullptr &&
name == containing_des.name())) {
std::string module = ModuleName(field_des.file()->name());
name = module + "." + name;
}
return name;
}
default: default:
GOOGLE_LOG(FATAL) << "Unsuppoted field type."; GOOGLE_LOG(FATAL) << "Unsupported field type.";
} }
return ""; return "";
} }
void PyiGenerator::PrintMessage(const Descriptor& message_descriptor, void PyiGenerator::PrintMessage(
bool is_nested) const { const Descriptor& message_descriptor, bool is_nested,
const std::map<std::string, std::string>& import_map) const {
if (!is_nested) { if (!is_nested) {
printer_->Print("\n"); printer_->Print("\n");
} }
std::string class_name = message_descriptor.name(); std::string class_name = message_descriptor.name();
printer_->Print("class $class_name$(_message.Message):\n", "class_name", std::string extra_base;
class_name); // A well-known type needs to inherit from its corresponding base class in
// net/proto2/python/internal/well_known_types.
if (IsWellKnownType(message_descriptor.full_name())) {
extra_base = ", _well_known_types." + message_descriptor.name();
} else {
extra_base = "";
}
printer_->Print("class $class_name$(_message.Message$extra_base$):\n",
"class_name", class_name, "extra_base", extra_base);
printer_->Indent(); printer_->Indent();
printer_->Indent(); printer_->Indent();
@ -361,7 +431,7 @@ void PyiGenerator::PrintMessage(const Descriptor& message_descriptor,
for (const auto& entry : nested_enums) { for (const auto& entry : nested_enums) {
PrintEnum(*entry); PrintEnum(*entry);
// Adds enum value to item_map which will be ordered and printed later // Adds enum value to item_map which will be ordered and printed later
AddEnumValue(*entry, &item_map); AddEnumValue(*entry, &item_map, import_map);
} }
// Prints nested messages // Prints nested messages
@ -374,7 +444,7 @@ void PyiGenerator::PrintMessage(const Descriptor& message_descriptor,
SortByName<Descriptor>()); SortByName<Descriptor>());
for (const auto& entry : nested_messages) { for (const auto& entry : nested_messages) {
PrintMessage(*entry, true); PrintMessage(*entry, true, import_map);
} }
// Adds extensions to item_map which will be ordered and printed later // Adds extensions to item_map which will be ordered and printed later
@ -384,7 +454,7 @@ void PyiGenerator::PrintMessage(const Descriptor& message_descriptor,
for (int i = 0; i < message_descriptor.field_count(); ++i) { for (int i = 0; i < message_descriptor.field_count(); ++i) {
const FieldDescriptor& field_des = *message_descriptor.field(i); const FieldDescriptor& field_des = *message_descriptor.field(i);
item_map[ToUpper(field_des.name()) + "_FIELD_NUMBER"] = item_map[ToUpper(field_des.name()) + "_FIELD_NUMBER"] =
"ClassVar[int]"; "_ClassVar[int]";
if (IsPythonKeyword(field_des.name())) { if (IsPythonKeyword(field_des.name())) {
continue; continue;
} }
@ -395,16 +465,16 @@ void PyiGenerator::PrintMessage(const Descriptor& message_descriptor,
field_type = (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE field_type = (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
? "_containers.MessageMap[" ? "_containers.MessageMap["
: "_containers.ScalarMap["); : "_containers.ScalarMap[");
field_type += GetFieldType(*key_des); field_type += GetFieldType(*key_des, message_descriptor, import_map);
field_type += ", "; field_type += ", ";
field_type += GetFieldType(*value_des); field_type += GetFieldType(*value_des, message_descriptor, import_map);
} else { } else {
if (field_des.is_repeated()) { if (field_des.is_repeated()) {
field_type = (field_des.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE field_type = (field_des.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
? "_containers.RepeatedCompositeFieldContainer[" ? "_containers.RepeatedCompositeFieldContainer["
: "_containers.RepeatedScalarFieldContainer["); : "_containers.RepeatedScalarFieldContainer[");
} }
field_type += GetFieldType(field_des); field_type += GetFieldType(field_des, message_descriptor, import_map);
} }
if (field_des.is_repeated()) { if (field_des.is_repeated()) {
@ -437,26 +507,31 @@ void PyiGenerator::PrintMessage(const Descriptor& message_descriptor,
printer_->Print(", $field_name$: ", "field_name", field_name); printer_->Print(", $field_name$: ", "field_name", field_name);
if (field_des->is_repeated() || if (field_des->is_repeated() ||
field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) { field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) {
printer_->Print("Optional["); printer_->Print("_Optional[");
} }
if (field_des->is_map()) { if (field_des->is_map()) {
const Descriptor* map_entry = field_des->message_type(); const Descriptor* map_entry = field_des->message_type();
printer_->Print("Mapping[$key_type$, $value_type$]", "key_type", printer_->Print(
GetFieldType(*map_entry->field(0)), "value_type", "_Mapping[$key_type$, $value_type$]", "key_type",
GetFieldType(*map_entry->field(1))); GetFieldType(*map_entry->field(0), message_descriptor, import_map),
"value_type",
GetFieldType(*map_entry->field(1), message_descriptor, import_map));
} else { } else {
if (field_des->is_repeated()) { if (field_des->is_repeated()) {
printer_->Print("Iterable["); printer_->Print("_Iterable[");
} }
if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
printer_->Print("Union[$type_name$, Mapping]", "type_name", printer_->Print(
GetFieldType(*field_des)); "_Union[$type_name$, _Mapping]", "type_name",
GetFieldType(*field_des, message_descriptor, import_map));
} else { } else {
if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
printer_->Print("Union[$type_name$, str]", "type_name", printer_->Print("_Union[$type_name$, str]", "type_name",
ModuleLevelName(*field_des->enum_type())); ModuleLevelName(*field_des->enum_type(), import_map));
} else { } else {
printer_->Print("$type_name$", "type_name", GetFieldType(*field_des)); printer_->Print(
"$type_name$", "type_name",
GetFieldType(*field_des, message_descriptor, import_map));
} }
} }
if (field_des->is_repeated()) { if (field_des->is_repeated()) {
@ -478,7 +553,8 @@ void PyiGenerator::PrintMessage(const Descriptor& message_descriptor,
printer_->Outdent(); printer_->Outdent();
} }
void PyiGenerator::PrintMessages() const { void PyiGenerator::PrintMessages(
const std::map<std::string, std::string>& import_map) const {
// Order the descriptors by name to have same output with proto_to_pyi.py // Order the descriptors by name to have same output with proto_to_pyi.py
std::vector<const Descriptor*> messages; std::vector<const Descriptor*> messages;
messages.reserve(file_->message_type_count()); messages.reserve(file_->message_type_count());
@ -488,7 +564,7 @@ void PyiGenerator::PrintMessages() const {
std::sort(messages.begin(), messages.end(), SortByName<Descriptor>()); std::sort(messages.begin(), messages.end(), SortByName<Descriptor>());
for (const auto& entry : messages) { for (const auto& entry : messages) {
PrintMessage(*entry, false); PrintMessage(*entry, false, import_map);
} }
} }
@ -534,17 +610,22 @@ bool PyiGenerator::Generate(const FileDescriptor* file,
// Adds "DESCRIPTOR" into item_map. // Adds "DESCRIPTOR" into item_map.
item_map["DESCRIPTOR"] = "_descriptor.FileDescriptor"; item_map["DESCRIPTOR"] = "_descriptor.FileDescriptor";
PrintImports(&item_map);
// import_map will be a mapping from filename to module alias, e.g.
// "google3/foo/bar.py" -> "_bar"
std::map<std::string, std::string> import_map;
PrintImports(&item_map, &import_map);
// Adds top level enum values to item_map. // Adds top level enum values to item_map.
for (int i = 0; i < file_->enum_type_count(); ++i) { for (int i = 0; i < file_->enum_type_count(); ++i) {
AddEnumValue(*file_->enum_type(i), &item_map); AddEnumValue(*file_->enum_type(i), &item_map, import_map);
} }
// Adds top level extensions to item_map. // Adds top level extensions to item_map.
AddExtensions(*file_, &item_map); AddExtensions(*file_, &item_map);
// Prints item map // Prints item map
PrintItemMap(item_map); PrintItemMap(item_map);
PrintMessages(); PrintMessages(import_map);
PrintTopLevelEnums(); PrintTopLevelEnums();
if (HasGenericServices(file)) { if (HasGenericServices(file)) {
PrintServices(); PrintServices();

@ -36,6 +36,7 @@
#define GOOGLE_PROTOBUF_COMPILER_PYTHON_PYI_GENERATOR_H__ #define GOOGLE_PROTOBUF_COMPILER_PYTHON_PYI_GENERATOR_H__
#include <map> #include <map>
#include <set>
#include <string> #include <string>
#include <google/protobuf/stubs/mutex.h> #include <google/protobuf/stubs/mutex.h>
@ -65,26 +66,41 @@ class PROTOC_EXPORT PyiGenerator : public google::protobuf::compiler::CodeGenera
~PyiGenerator() override; ~PyiGenerator() override;
// CodeGenerator methods. // CodeGenerator methods.
uint64_t GetSupportedFeatures() const override {
// Code generators must explicitly support proto3 optional.
return CodeGenerator::FEATURE_PROTO3_OPTIONAL;
}
bool Generate(const FileDescriptor* file, const std::string& parameter, bool Generate(const FileDescriptor* file, const std::string& parameter,
GeneratorContext* generator_context, GeneratorContext* generator_context,
std::string* error) const override; std::string* error) const override;
private: private:
void PrintImports(std::map<std::string, std::string>* item_map) const; void PrintImportForDescriptor(const FileDescriptor& desc,
std::map<std::string, std::string>* import_map,
std::set<std::string>* seen_aliases) const;
void PrintImports(std::map<std::string, std::string>* item_map,
std::map<std::string, std::string>* import_map) const;
void PrintEnum(const EnumDescriptor& enum_descriptor) const; void PrintEnum(const EnumDescriptor& enum_descriptor) const;
void AddEnumValue(const EnumDescriptor& enum_descriptor, void AddEnumValue(const EnumDescriptor& enum_descriptor,
std::map<std::string, std::string>* item_map) const; std::map<std::string, std::string>* item_map,
const std::map<std::string, std::string>& import_map) const;
void PrintTopLevelEnums() const; void PrintTopLevelEnums() const;
template <typename DescriptorT> template <typename DescriptorT>
void AddExtensions(const DescriptorT& descriptor, void AddExtensions(const DescriptorT& descriptor,
std::map<std::string, std::string>* item_map) const; std::map<std::string, std::string>* item_map) const;
void PrintMessages() const; void PrintMessages(
void PrintMessage(const Descriptor& message_descriptor, bool is_nested) const; const std::map<std::string, std::string>& import_map) const;
void PrintMessage(const Descriptor& message_descriptor, bool is_nested,
const std::map<std::string, std::string>& import_map) const;
void PrintServices() const; void PrintServices() const;
void PrintItemMap(const std::map<std::string, std::string>& item_map) const; void PrintItemMap(const std::map<std::string, std::string>& item_map) const;
std::string GetFieldType(const FieldDescriptor& field_des) const; std::string GetFieldType(
const FieldDescriptor& field_des, const Descriptor& containing_des,
const std::map<std::string, std::string>& import_map) const;
template <typename DescriptorT> template <typename DescriptorT>
std::string ModuleLevelName(const DescriptorT& descriptor) const; std::string ModuleLevelName(
const DescriptorT& descriptor,
const std::map<std::string, std::string>& import_map) const;
// Very coarse-grained lock to ensure that Generate() is reentrant. // Very coarse-grained lock to ensure that Generate() is reentrant.
// Guards file_ and printer_. // Guards file_ and printer_.

@ -580,6 +580,35 @@ const char* TcParser::FastF64P2(PROTOBUF_TC_PARAM_DECL) {
namespace { namespace {
// Shift "byte" left by n * 7 bits, filling vacated bits with ones.
template <int n>
inline PROTOBUF_ALWAYS_INLINE uint64_t
shift_left_fill_with_ones(uint64_t byte, uint64_t ones) {
return (byte << (n * 7)) | (ones >> (64 - (n * 7)));
}
// Shift "byte" left by n * 7 bits, filling vacated bits with ones, and
// put the new value in res. Return whether the result was negative.
template <int n>
inline PROTOBUF_ALWAYS_INLINE bool shift_left_fill_with_ones_was_negative(
uint64_t byte, uint64_t ones, int64_t& res) {
#if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__)
// For the first two rounds (ptr[1] and ptr[2]), micro benchmarks show a
// substantial improvement from capturing the sign from the condition code
// register on x86-64.
bool sign_bit;
asm("shldq %3, %2, %1"
: "=@ccs"(sign_bit), "+r"(byte)
: "r"(ones), "i"(n * 7));
res = byte;
return sign_bit;
#else
// Generic fallback:
res = (byte << (n * 7)) | (ones >> (64 - (n * 7)));
return static_cast<int64_t>(res) < 0;
#endif
}
inline PROTOBUF_ALWAYS_INLINE std::pair<const char*, uint64_t> inline PROTOBUF_ALWAYS_INLINE std::pair<const char*, uint64_t>
Parse64FallbackPair(const char* p, int64_t res1) { Parse64FallbackPair(const char* p, int64_t res1) {
auto ptr = reinterpret_cast<const int8_t*>(p); auto ptr = reinterpret_cast<const int8_t*>(p);
@ -601,78 +630,42 @@ Parse64FallbackPair(const char* p, int64_t res1) {
// has 57 high bits of ones, which is enough for the largest shift done. // has 57 high bits of ones, which is enough for the largest shift done.
GOOGLE_DCHECK_EQ(res1 >> 7, -1); GOOGLE_DCHECK_EQ(res1 >> 7, -1);
uint64_t ones = res1; // save the high 1 bits from res1 (input to SHLD) uint64_t ones = res1; // save the high 1 bits from res1 (input to SHLD)
uint64_t byte; // the "next" 7-bit chunk, shifted (result from SHLD)
int64_t res2, res3; // accumulated result chunks int64_t res2, res3; // accumulated result chunks
#define SHLD(n) byte = ((byte << (n * 7)) | (ones >> (64 - (n * 7))))
int sign_bit;
#if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__)
// For the first two rounds (ptr[1] and ptr[2]), micro benchmarks show a
// substantial improvement from capturing the sign from the condition code
// register on x86-64.
#define SHLD_SIGN(n) \
asm("shldq %3, %2, %1" \
: "=@ccs"(sign_bit), "+r"(byte) \
: "r"(ones), "i"(n * 7))
#else
// Generic fallback:
#define SHLD_SIGN(n) \
do { \
SHLD(n); \
sign_bit = static_cast<int64_t>(byte) < 0; \
} while (0)
#endif
byte = ptr[1]; if (!shift_left_fill_with_ones_was_negative<1>(ptr[1], ones, res2))
SHLD_SIGN(1); goto done2;
res2 = byte; if (!shift_left_fill_with_ones_was_negative<2>(ptr[2], ones, res3))
if (!sign_bit) goto done2; goto done3;
byte = ptr[2];
SHLD_SIGN(2);
res3 = byte;
if (!sign_bit) goto done3;
#undef SHLD_SIGN
// For the remainder of the chunks, check the sign of the AND result. // For the remainder of the chunks, check the sign of the AND result.
byte = ptr[3]; res1 &= shift_left_fill_with_ones<3>(ptr[3], ones);
SHLD(3);
res1 &= byte;
if (res1 >= 0) goto done4; if (res1 >= 0) goto done4;
byte = ptr[4]; res2 &= shift_left_fill_with_ones<4>(ptr[4], ones);
SHLD(4);
res2 &= byte;
if (res2 >= 0) goto done5; if (res2 >= 0) goto done5;
byte = ptr[5]; res3 &= shift_left_fill_with_ones<5>(ptr[5], ones);
SHLD(5);
res3 &= byte;
if (res3 >= 0) goto done6; if (res3 >= 0) goto done6;
byte = ptr[6]; res1 &= shift_left_fill_with_ones<6>(ptr[6], ones);
SHLD(6);
res1 &= byte;
if (res1 >= 0) goto done7; if (res1 >= 0) goto done7;
byte = ptr[7]; res2 &= shift_left_fill_with_ones<7>(ptr[7], ones);
SHLD(7);
res2 &= byte;
if (res2 >= 0) goto done8; if (res2 >= 0) goto done8;
byte = ptr[8]; res3 &= shift_left_fill_with_ones<8>(ptr[8], ones);
SHLD(8);
res3 &= byte;
if (res3 >= 0) goto done9; if (res3 >= 0) goto done9;
#undef SHLD
// For valid 64bit varints, the 10th byte/ptr[9] should be exactly 1. In this // For valid 64bit varints, the 10th byte/ptr[9] should be exactly 1. In this
// case, the continuation bit of ptr[8] already set the top bit of res3 // case, the continuation bit of ptr[8] already set the top bit of res3
// correctly, so all we have to do is check that the expected case is true. // correctly, so all we have to do is check that the expected case is true.
byte = ptr[9]; if (PROTOBUF_PREDICT_TRUE(ptr[9] == 1)) goto done10;
if (PROTOBUF_PREDICT_TRUE(byte == 1)) goto done10;
// A value of 0, however, represents an over-serialized varint. This case // A value of 0, however, represents an over-serialized varint. This case
// should not happen, but if does (say, due to a nonconforming serializer), // should not happen, but if does (say, due to a nonconforming serializer),
// deassert the continuation bit that came from ptr[8]. // deassert the continuation bit that came from ptr[8].
if (byte == 0) { if (ptr[9] == 0) {
#if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__)
// Use a small instruction since this is an uncommon code path.
asm("btcq $63,%0" : "+r"(res3));
#else
res3 ^= static_cast<uint64_t>(1) << 63; res3 ^= static_cast<uint64_t>(1) << 63;
#endif
goto done10; goto done10;
} }
@ -680,18 +673,24 @@ Parse64FallbackPair(const char* p, int64_t res1) {
// fit in 64 bits. If the continue bit is set, it is an unterminated varint. // fit in 64 bits. If the continue bit is set, it is an unterminated varint.
return {nullptr, 0}; return {nullptr, 0};
#define DONE(n) done##n : return {p + n, res1 & res2 & res3};
done2: done2:
return {p + 2, res1 & res2}; return {p + 2, res1 & res2};
DONE(3) done3:
DONE(4) return {p + 3, res1 & res2 & res3};
DONE(5) done4:
DONE(6) return {p + 4, res1 & res2 & res3};
DONE(7) done5:
DONE(8) return {p + 5, res1 & res2 & res3};
DONE(9) done6:
DONE(10) return {p + 6, res1 & res2 & res3};
#undef DONE done7:
return {p + 7, res1 & res2 & res3};
done8:
return {p + 8, res1 & res2 & res3};
done9:
return {p + 9, res1 & res2 & res3};
done10:
return {p + 10, res1 & res2 & res3};
} }
inline PROTOBUF_ALWAYS_INLINE const char* ParseVarint(const char* p, inline PROTOBUF_ALWAYS_INLINE const char* ParseVarint(const char* p,

@ -150,12 +150,32 @@ CHARACTER_CLASS(Escape, c == 'a' || c == 'b' || c == 'f' || c == 'n' ||
// Given a char, interpret it as a numeric digit and return its value. // Given a char, interpret it as a numeric digit and return its value.
// This supports any number base up to 36. // This supports any number base up to 36.
inline int DigitValue(char digit) { // Represents integer values of digits.
if ('0' <= digit && digit <= '9') return digit - '0'; // Uses 36 to indicate an invalid character since we support
if ('a' <= digit && digit <= 'z') return digit - 'a' + 10; // bases up to 36.
if ('A' <= digit && digit <= 'Z') return digit - 'A' + 10; static const int8_t kAsciiToInt[256] = {
return -1; 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // 00-0F
} 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // 10-1F
36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // ' '-'/'
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // '0'-'9'
36, 36, 36, 36, 36, 36, 36, // ':'-'@'
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'P'
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, // 'Q'-'Z'
36, 36, 36, 36, 36, 36, // '['-'`'
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'a'-'p'
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, // 'q'-'z'
36, 36, 36, 36, 36, // '{'-DEL
36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // 80-8F
36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // 90-9F
36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // A0-AF
36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // B0-BF
36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // C0-CF
36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // D0-DF
36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // E0-EF
36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, // F0-FF
};
inline int DigitValue(char digit) { return kAsciiToInt[digit & 0xFF]; }
// Inline because it's only used in one place. // Inline because it's only used in one place.
inline char TranslateEscape(char c) { inline char TranslateEscape(char c) {
@ -914,25 +934,49 @@ bool Tokenizer::NextWithComments(std::string* prev_trailing_comments,
bool Tokenizer::ParseInteger(const std::string& text, uint64_t max_value, bool Tokenizer::ParseInteger(const std::string& text, uint64_t max_value,
uint64_t* output) { uint64_t* output) {
// Sadly, we can't just use strtoul() since it is only 32-bit and strtoull() // We can't just use strtoull() because (a) it accepts negative numbers,
// is non-standard. I hate the C standard library. :( // (b) We want additional range checks, (c) it reports overflows via errno.
// return strtoull(text.c_str(), NULL, 0); #if 0
const char *str_begin = text.c_str();
if (*str_begin == '-') return false;
char *str_end = nullptr;
errno = 0;
*output = std::strtoull(str_begin, &str_end, 0);
return (errno == 0 && str_end && *str_end == '\0' && *output <= max_value);
#endif
const char* ptr = text.c_str(); const char* ptr = text.c_str();
int base = 10; int base = 10;
uint64_t overflow_if_mul_base = (kuint64max / 10) + 1;
if (ptr[0] == '0') { if (ptr[0] == '0') {
if (ptr[1] == 'x' || ptr[1] == 'X') { if (ptr[1] == 'x' || ptr[1] == 'X') {
// This is hex. // This is hex.
base = 16; base = 16;
overflow_if_mul_base = (kuint64max / 16) + 1;
ptr += 2; ptr += 2;
} else { } else {
// This is octal. // This is octal.
base = 8; base = 8;
overflow_if_mul_base = (kuint64max / 8) + 1;
} }
} }
uint64_t result = 0; uint64_t result = 0;
// For all the leading '0's, and also the first non-zero character, we
// don't need to multiply.
while (*ptr != '\0') {
int digit = DigitValue(*ptr++);
if (digit >= base) {
// The token provided by Tokenizer is invalid. i.e., 099 is an invalid
// token, but Tokenizer still think it's integer.
return false;
}
if (digit != 0) {
result = digit;
break;
}
}
for (; *ptr != '\0'; ptr++) { for (; *ptr != '\0'; ptr++) {
int digit = DigitValue(*ptr); int digit = DigitValue(*ptr);
if (digit < 0 || digit >= base) { if (digit < 0 || digit >= base) {
@ -940,24 +984,18 @@ bool Tokenizer::ParseInteger(const std::string& text, uint64_t max_value,
// token, but Tokenizer still think it's integer. // token, but Tokenizer still think it's integer.
return false; return false;
} }
if (static_cast<uint64_t>(digit) > max_value) return false; if (result >= overflow_if_mul_base) {
#if PROTOBUF_HAS_BUILTIN_MUL_OVERFLOW // We know the multiply we're about to do will overflow, so exit now.
// If there is a uint64_t overflow, there is a result * base logical
// overflow. This is done to avoid division.
if (__builtin_mul_overflow(result, base, &result) ||
result > (max_value - digit)) {
// Overflow.
return false;
}
result += digit;
#else
if (result > (max_value - digit) / base) {
// Overflow.
return false; return false;
} }
// We know that result * base won't overflow, but adding digit might...
result = result * base + digit; result = result * base + digit;
#endif // PROTOBUF_HAS_BUILTIN_MUL_OVERFLOW // C++ guarantees defined "wrap" semantics when unsigned integer
// operations overflow, making this a fast way to check if adding
// digit made result overflow, and thus, wrap around.
if (result < static_cast<uint64_t>(base)) return false;
} }
if (result > max_value) return false;
*output = result; *output = result;
return true; return true;
@ -1199,4 +1237,3 @@ bool Tokenizer::IsIdentifier(const std::string& text) {
} // namespace google } // namespace google
#include <google/protobuf/port_undef.inc> #include <google/protobuf/port_undef.inc>

@ -178,9 +178,10 @@ const int kBlockSizes[] = {1, 2, 3, 5, 7, 13, 32, 1024};
class TokenizerTest : public testing::Test { class TokenizerTest : public testing::Test {
protected: protected:
// For easy testing. // For easy testing.
uint64 ParseInteger(const std::string& text) { uint64_t ParseInteger(const std::string& text) {
uint64 result; uint64_t result;
EXPECT_TRUE(Tokenizer::ParseInteger(text, kuint64max, &result)); EXPECT_TRUE(Tokenizer::ParseInteger(text, kuint64max, &result))
<< "'" << text << "'";
return result; return result;
} }
}; };
@ -809,8 +810,8 @@ TEST_2D(TokenizerTest, DocComments, kDocCommentCases, kBlockSizes) {
// ------------------------------------------------------------------- // -------------------------------------------------------------------
// Test parse helpers. It's not really worth setting up a full data-driven // Test parse helpers.
// test here. // TODO(b/225783758): Add a fuzz test for this.
TEST_F(TokenizerTest, ParseInteger) { TEST_F(TokenizerTest, ParseInteger) {
EXPECT_EQ(0, ParseInteger("0")); EXPECT_EQ(0, ParseInteger("0"));
EXPECT_EQ(123, ParseInteger("123")); EXPECT_EQ(123, ParseInteger("123"));
@ -823,7 +824,7 @@ TEST_F(TokenizerTest, ParseInteger) {
// Test invalid integers that may still be tokenized as integers. // Test invalid integers that may still be tokenized as integers.
EXPECT_EQ(0, ParseInteger("0x")); EXPECT_EQ(0, ParseInteger("0x"));
uint64 i; uint64_t i;
// Test invalid integers that will never be tokenized as integers. // Test invalid integers that will never be tokenized as integers.
EXPECT_FALSE(Tokenizer::ParseInteger("zxy", kuint64max, &i)); EXPECT_FALSE(Tokenizer::ParseInteger("zxy", kuint64max, &i));
@ -840,6 +841,105 @@ TEST_F(TokenizerTest, ParseInteger) {
EXPECT_FALSE(Tokenizer::ParseInteger("12346", 12345, &i)); EXPECT_FALSE(Tokenizer::ParseInteger("12346", 12345, &i));
EXPECT_TRUE(Tokenizer::ParseInteger("0xFFFFFFFFFFFFFFFF", kuint64max, &i)); EXPECT_TRUE(Tokenizer::ParseInteger("0xFFFFFFFFFFFFFFFF", kuint64max, &i));
EXPECT_FALSE(Tokenizer::ParseInteger("0x10000000000000000", kuint64max, &i)); EXPECT_FALSE(Tokenizer::ParseInteger("0x10000000000000000", kuint64max, &i));
// Test near the limits of signed parsing (values in kint64max +/- 1600)
for (int64_t offset = -1600; offset <= 1600; ++offset) {
uint64_t i = 0x7FFFFFFFFFFFFFFF + offset;
char decimal[32];
snprintf(decimal, 32, "%llu", static_cast<unsigned long long>(i));
if (offset > 0) {
uint64_t parsed = -1;
EXPECT_FALSE(Tokenizer::ParseInteger(decimal, kint64max, &parsed))
<< decimal << "=>" << parsed;
} else {
uint64_t parsed = -1;
EXPECT_TRUE(Tokenizer::ParseInteger(decimal, kint64max, &parsed))
<< decimal << "=>" << parsed;
EXPECT_EQ(parsed, i);
}
char octal[32];
snprintf(octal, 32, "0%llo", static_cast<unsigned long long>(i));
if (offset > 0) {
uint64_t parsed = -1;
EXPECT_FALSE(Tokenizer::ParseInteger(octal, kint64max, &parsed))
<< octal << "=>" << parsed;
} else {
uint64_t parsed = -1;
EXPECT_TRUE(Tokenizer::ParseInteger(octal, kint64max, &parsed))
<< octal << "=>" << parsed;
EXPECT_EQ(parsed, i);
}
char hex[32];
snprintf(hex, 32, "0x%llx", static_cast<unsigned long long>(i));
if (offset > 0) {
uint64_t parsed = -1;
EXPECT_FALSE(Tokenizer::ParseInteger(hex, kint64max, &parsed))
<< hex << "=>" << parsed;
} else {
uint64_t parsed = -1;
EXPECT_TRUE(Tokenizer::ParseInteger(hex, kint64max, &parsed)) << hex;
EXPECT_EQ(parsed, i);
}
// EXPECT_NE(offset, -237);
}
// Test near the limits of unsigned parsing (values in kuint64max +/- 1600)
// By definition, values greater than kuint64max cannot be held in a uint64_t
// variable, so printing them is a little tricky; fortunately all but the
// last four digits are known, so we can hard-code them in the printf string,
// and we only need to format the last 4.
for (int64_t offset = -1600; offset <= 1600; ++offset) {
{
uint64_t i = 18446744073709551615u + offset;
char decimal[32];
snprintf(decimal, 32, "1844674407370955%04llu",
static_cast<unsigned long long>(1615 + offset));
if (offset > 0) {
uint64_t parsed = -1;
EXPECT_FALSE(Tokenizer::ParseInteger(decimal, kuint64max, &parsed))
<< decimal << "=>" << parsed;
} else {
uint64_t parsed = -1;
EXPECT_TRUE(Tokenizer::ParseInteger(decimal, kuint64max, &parsed))
<< decimal;
EXPECT_EQ(parsed, i);
}
}
{
uint64_t i = 01777777777777777777777u + offset;
if (offset > 0) {
char octal[32];
snprintf(octal, 32, "0200000000000000000%04llo",
static_cast<unsigned long long>(offset - 1));
uint64_t parsed = -1;
EXPECT_FALSE(Tokenizer::ParseInteger(octal, kuint64max, &parsed))
<< octal << "=>" << parsed;
} else {
char octal[32];
snprintf(octal, 32, "0%llo", static_cast<unsigned long long>(i));
uint64_t parsed = -1;
EXPECT_TRUE(Tokenizer::ParseInteger(octal, kuint64max, &parsed))
<< octal;
EXPECT_EQ(parsed, i);
}
}
{
uint64_t ui = 0xffffffffffffffffu + offset;
char hex[32];
if (offset > 0) {
snprintf(hex, 32, "0x1000000000000%04llx",
static_cast<unsigned long long>(offset - 1));
uint64_t parsed = -1;
EXPECT_FALSE(Tokenizer::ParseInteger(hex, kuint64max, &parsed))
<< hex << "=>" << parsed;
} else {
snprintf(hex, 32, "0x%llx", static_cast<unsigned long long>(ui));
uint64_t parsed = -1;
EXPECT_TRUE(Tokenizer::ParseInteger(hex, kuint64max, &parsed)) << hex;
EXPECT_EQ(parsed, ui);
}
}
}
} }
TEST_F(TokenizerTest, ParseFloat) { TEST_F(TokenizerTest, ParseFloat) {

@ -535,17 +535,6 @@
#define PROTOBUF_ASSUME(pred) GOOGLE_DCHECK(pred) #define PROTOBUF_ASSUME(pred) GOOGLE_DCHECK(pred)
#endif #endif
// PROTOBUF_HAS_BUILTIN_MUL_OVERFLOW tells the compiler if it has
// __builtin_mul_overflow intrinsic to check for multiplication overflow.
#ifdef PROTOBUF_HAS_BUILTIN_MUL_OVERFLOW
#error PROTOBUF_HAS_BUILTIN_MUL_OVERFLOW was previously defined
#endif
#if __has_builtin(__builtin_mul_overflow)
#define PROTOBUF_HAS_BUILTIN_MUL_OVERFLOW 1
#else
#define PROTOBUF_HAS_BUILTIN_MUL_OVERFLOW 0
#endif
// Specify memory alignment for structs, classes, etc. // Specify memory alignment for structs, classes, etc.
// Use like: // Use like:
// class PROTOBUF_ALIGNAS(16) MyClass { ... } // class PROTOBUF_ALIGNAS(16) MyClass { ... }
@ -782,6 +771,8 @@
#undef ERROR_INSTALL_FAILED #undef ERROR_INSTALL_FAILED
#pragma push_macro("ERROR_NOT_FOUND") #pragma push_macro("ERROR_NOT_FOUND")
#undef ERROR_NOT_FOUND #undef ERROR_NOT_FOUND
#pragma push_macro("GetClassName")
#undef GetClassName
#pragma push_macro("GetMessage") #pragma push_macro("GetMessage")
#undef GetMessage #undef GetMessage
#pragma push_macro("IGNORE") #pragma push_macro("IGNORE")

@ -72,7 +72,6 @@
#undef PROTOBUF_NAMESPACE_CLOSE #undef PROTOBUF_NAMESPACE_CLOSE
#undef PROTOBUF_UNUSED #undef PROTOBUF_UNUSED
#undef PROTOBUF_ASSUME #undef PROTOBUF_ASSUME
#undef PROTOBUF_HAS_BUILTIN_MUL_OVERFLOW
#undef PROTOBUF_EXPORT_TEMPLATE_DECLARE #undef PROTOBUF_EXPORT_TEMPLATE_DECLARE
#undef PROTOBUF_EXPORT_TEMPLATE_DEFINE #undef PROTOBUF_EXPORT_TEMPLATE_DEFINE
#undef PROTOBUF_ALIGNAS #undef PROTOBUF_ALIGNAS
@ -112,6 +111,7 @@
#pragma pop_macro("ERROR_BUSY") #pragma pop_macro("ERROR_BUSY")
#pragma pop_macro("ERROR_INSTALL_FAILED") #pragma pop_macro("ERROR_INSTALL_FAILED")
#pragma pop_macro("ERROR_NOT_FOUND") #pragma pop_macro("ERROR_NOT_FOUND")
#pragma pop_macro("GetClassName")
#pragma pop_macro("GetMessage") #pragma pop_macro("GetMessage")
#pragma pop_macro("IGNORE") #pragma pop_macro("IGNORE")
#pragma pop_macro("IN") #pragma pop_macro("IN")

@ -749,7 +749,7 @@ class GenericTypeHandler {
static inline GenericType* New(Arena* arena, GenericType&& value) { static inline GenericType* New(Arena* arena, GenericType&& value) {
return Arena::Create<GenericType>(arena, std::move(value)); return Arena::Create<GenericType>(arena, std::move(value));
} }
static inline GenericType* NewFromPrototype(const GenericType* prototype, static inline GenericType* NewFromPrototype(const GenericType* /*prototype*/,
Arena* arena = nullptr) { Arena* arena = nullptr) {
return New(arena); return New(arena);
} }

@ -803,7 +803,7 @@ class TextFormat::Parser::ParserImpl {
case FieldDescriptor::CPPTYPE_STRING: { case FieldDescriptor::CPPTYPE_STRING: {
std::string value; std::string value;
DO(ConsumeString(&value)); DO(ConsumeString(&value));
SET_FIELD(String, value); SET_FIELD(String, std::move(value));
break; break;
} }

Loading…
Cancel
Save