Breaking Change: Made text_format output default to UTF-8.

Also hardened the text format printer against invalid UTF-8 in string fields.  The output string will always be valid UTF-8, even if string fields contain invalid UTF-8.

PiperOrigin-RevId: 600990001
pull/15539/head
Joshua Haberman 10 months ago committed by Copybara-Service
parent b9e4894462
commit bf00034493
  1. 10
      python/google/protobuf/internal/text_encoding_test.py
  2. 52
      python/google/protobuf/internal/text_format_test.py
  3. 70
      python/google/protobuf/text_encoding.py
  4. 14
      python/google/protobuf/text_format.py

@ -22,17 +22,17 @@ TEST_VALUES = [
"signi\\\\fying\\\\ nothing\\\\", "signi\\\\fying\\\\ nothing\\\\",
b"signi\\fying\\ nothing\\"), b"signi\\fying\\ nothing\\"),
("\\010\\t\\n\\013\\014\\r", ("\\010\\t\\n\\013\\014\\r",
"\x08\\t\\n\x0b\x0c\\r", "\\010\\t\\n\\013\\014\\r",
b"\010\011\012\013\014\015")] b"\010\011\012\013\014\015")]
class TextEncodingTestCase(unittest.TestCase): class TextEncodingTestCase(unittest.TestCase):
def testCEscape(self): def testCEscape(self):
for escaped, escaped_utf8, unescaped in TEST_VALUES: for escaped, escaped_utf8, unescaped in TEST_VALUES:
self.assertEqual(escaped, self.assertEqual(escaped, text_encoding.CEscape(unescaped, as_utf8=False))
text_encoding.CEscape(unescaped, as_utf8=False)) self.assertEqual(
self.assertEqual(escaped_utf8, escaped_utf8, text_encoding.CEscape(unescaped, as_utf8=True)
text_encoding.CEscape(unescaped, as_utf8=True)) )
def testCUnescape(self): def testCUnescape(self):
for escaped, escaped_utf8, unescaped in TEST_VALUES: for escaped, escaped_utf8, unescaped in TEST_VALUES:

@ -86,7 +86,9 @@ class TextFormatMessageToStringTests(TextFormatBase):
message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
message.repeated_string.append(u'\u00fc\ua71f') message.repeated_string.append(u'\u00fc\ua71f')
self.CompareToGoldenText( self.CompareToGoldenText(
self.RemoveRedundantZeros(text_format.MessageToString(message)), self.RemoveRedundantZeros(
text_format.MessageToString(message, as_utf8=True)
),
'repeated_int64: -9223372036854775808\n' 'repeated_int64: -9223372036854775808\n'
'repeated_uint64: 18446744073709551615\n' 'repeated_uint64: 18446744073709551615\n'
'repeated_double: 123.456\n' 'repeated_double: 123.456\n'
@ -94,7 +96,8 @@ class TextFormatMessageToStringTests(TextFormatBase):
'repeated_double: 1.23e-18\n' 'repeated_double: 1.23e-18\n'
'repeated_string:' 'repeated_string:'
' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' ' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n'
'repeated_string: "\\303\\274\\352\\234\\237"\n') 'repeated_string: "üꜟ"\n',
)
def testPrintFloatPrecision(self, message_module): def testPrintFloatPrecision(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
@ -204,8 +207,8 @@ class TextFormatMessageToStringTests(TextFormatBase):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f')) message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f'))
self.CompareToGoldenText( self.CompareToGoldenText(
text_format.MessageToString(message), text_format.MessageToString(message, as_utf8=True),
'repeated_string: "\\303\\274\\352\\234\\237"\n') 'repeated_string: "üꜟ"\n')
def testPrintNestedMessageAsOneLine(self, message_module): def testPrintNestedMessageAsOneLine(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
@ -282,7 +285,7 @@ class TextFormatMessageToStringTests(TextFormatBase):
message.repeated_string.append(u'\u00fc\ua71f') message.repeated_string.append(u'\u00fc\ua71f')
self.CompareToGoldenText( self.CompareToGoldenText(
self.RemoveRedundantZeros(text_format.MessageToString( self.RemoveRedundantZeros(text_format.MessageToString(
message, as_one_line=True)), message, as_one_line=True, as_utf8=True)),
'repeated_int64: -9223372036854775808' 'repeated_int64: -9223372036854775808'
' repeated_uint64: 18446744073709551615' ' repeated_uint64: 18446744073709551615'
' repeated_double: 123.456' ' repeated_double: 123.456'
@ -290,7 +293,7 @@ class TextFormatMessageToStringTests(TextFormatBase):
' repeated_double: 1.23e-18' ' repeated_double: 1.23e-18'
' repeated_string: ' ' repeated_string: '
'"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""' '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""'
' repeated_string: "\\303\\274\\352\\234\\237"') ' repeated_string: "üꜟ"')
def testRoundTripExoticAsOneLine(self, message_module): def testRoundTripExoticAsOneLine(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
@ -616,8 +619,8 @@ class TextFormatMessageToTextBytesTests(TextFormatBase):
def testRawUtf8RoundTrip(self, message_module): def testRawUtf8RoundTrip(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
message.repeated_string.append(u'\u00fc\t\ua71f') message.repeated_string.append(u'\u00fc\t\ua71f')
utf8_text = text_format.MessageToBytes(message, as_utf8=True) utf8_text = text_format.MessageToBytes(message, as_utf8=False)
golden_bytes = b'repeated_string: "\xc3\xbc\\t\xea\x9c\x9f"\n' golden_bytes = b'repeated_string: "\\303\\274\\t\\352\\234\\237"\n'
self.CompareToGoldenText(utf8_text, golden_bytes) self.CompareToGoldenText(utf8_text, golden_bytes)
parsed_message = message_module.TestAllTypes() parsed_message = message_module.TestAllTypes()
text_format.Parse(utf8_text, parsed_message) text_format.Parse(utf8_text, parsed_message)
@ -626,10 +629,41 @@ class TextFormatMessageToTextBytesTests(TextFormatBase):
(message, parsed_message, message.repeated_string[0], (message, parsed_message, message.repeated_string[0],
parsed_message.repeated_string[0])) parsed_message.repeated_string[0]))
def testRawUtf8RoundTripAsUtf8(self, message_module):
message = message_module.TestAllTypes()
message.repeated_string.append(u'\u00fc\t\ua71f')
utf8_text = text_format.MessageToString(message, as_utf8=True)
parsed_message = message_module.TestAllTypes()
text_format.Parse(utf8_text, parsed_message)
self.assertEqual(
message, parsed_message, '\n%s != %s (%s != %s)' %
(message, parsed_message, message.repeated_string[0],
parsed_message.repeated_string[0]))
# We can only test this case under proto2, because proto3 will reject invalid
# UTF-8 in the parser, so there should be no way of creating a string field
# that contains invalid UTF-8.
#
# We also can't test it in pure-Python, which validates all string fields for
# UTF-8 even when the spec says it shouldn't.
@unittest.skipIf(api_implementation.Type() == 'python',
'Python can\'t create invalid UTF-8 strings')
def testInvalidUtf8RoundTrip(self, message_module):
if message_module is not unittest_pb2:
return
one_bytes = unittest_pb2.OneBytes()
one_bytes.data = b'ABC\xff123'
one_string = unittest_pb2.OneString()
one_string.ParseFromString(one_bytes.SerializeToString())
self.assertIn(
'data: "ABC\\377123"',
text_format.MessageToString(one_string, as_utf8=True),
)
def testEscapedUtf8ASCIIRoundTrip(self, message_module): def testEscapedUtf8ASCIIRoundTrip(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
message.repeated_string.append(u'\u00fc\t\ua71f') message.repeated_string.append(u'\u00fc\t\ua71f')
ascii_text = text_format.MessageToBytes(message) # as_utf8=False default ascii_text = text_format.MessageToBytes(message, as_utf8=False)
golden_bytes = b'repeated_string: "\\303\\274\\t\\352\\234\\237"\n' golden_bytes = b'repeated_string: "\\303\\274\\t\\352\\234\\237"\n'
self.CompareToGoldenText(ascii_text, golden_bytes) self.CompareToGoldenText(ascii_text, golden_bytes)
parsed_message = message_module.TestAllTypes() parsed_message = message_module.TestAllTypes()

@ -8,26 +8,42 @@
"""Encoding related utilities.""" """Encoding related utilities."""
import re import re
_cescape_chr_to_symbol_map = {} def _AsciiIsPrint(i):
_cescape_chr_to_symbol_map[9] = r'\t' # optional escape return i >= 32 and i < 127
_cescape_chr_to_symbol_map[10] = r'\n' # optional escape
_cescape_chr_to_symbol_map[13] = r'\r' # optional escape def _MakeStrEscapes():
_cescape_chr_to_symbol_map[34] = r'\"' # necessary escape ret = {}
_cescape_chr_to_symbol_map[39] = r"\'" # optional escape for i in range(0, 128):
_cescape_chr_to_symbol_map[92] = r'\\' # necessary escape if not _AsciiIsPrint(i):
ret[i] = r'\%03o' % i
# Lookup table for unicode ret[ord('\t')] = r'\t' # optional escape
_cescape_unicode_to_str = [chr(i) for i in range(0, 256)] ret[ord('\n')] = r'\n' # optional escape
for byte, string in _cescape_chr_to_symbol_map.items(): ret[ord('\r')] = r'\r' # optional escape
_cescape_unicode_to_str[byte] = string ret[ord('"')] = r'\"' # necessary escape
ret[ord('\'')] = r"\'" # optional escape
# Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32) ret[ord('\\')] = r'\\' # necessary escape
_cescape_byte_to_str = ([r'\%03o' % i for i in range(0, 32)] + return ret
[chr(i) for i in range(32, 127)] +
[r'\%03o' % i for i in range(127, 256)]) # Maps int -> char, performing string escapes.
for byte, string in _cescape_chr_to_symbol_map.items(): _str_escapes = _MakeStrEscapes()
_cescape_byte_to_str[byte] = string
del byte, string # Maps int -> char, performing byte escaping and string escapes
_byte_escapes = {i: chr(i) for i in range(0, 256)}
_byte_escapes.update(_str_escapes)
_byte_escapes.update({i: r'\%03o' % i for i in range(128, 256)})
def _DecodeUtf8EscapeErrors(text_bytes):
ret = ''
while text_bytes:
try:
ret += text_bytes.decode('utf-8').translate(_str_escapes)
text_bytes = ''
except UnicodeDecodeError as e:
ret += text_bytes[:e.start].decode('utf-8').translate(_str_escapes)
ret += _byte_escapes[text_bytes[e.start]]
text_bytes = text_bytes[e.start+1:]
return ret
def CEscape(text, as_utf8) -> str: def CEscape(text, as_utf8) -> str:
@ -47,13 +63,15 @@ def CEscape(text, as_utf8) -> str:
# length. So, "\0011".encode('string_escape') ends up being "\\x011", which # length. So, "\0011".encode('string_escape') ends up being "\\x011", which
# will be decoded in C++ as a single-character string with char code 0x11. # will be decoded in C++ as a single-character string with char code 0x11.
text_is_unicode = isinstance(text, str) text_is_unicode = isinstance(text, str)
if as_utf8 and text_is_unicode:
# We're already unicode, no processing beyond control char escapes.
return text.translate(_cescape_chr_to_symbol_map)
ord_ = ord if text_is_unicode else lambda x: x # bytes iterate as ints.
if as_utf8: if as_utf8:
return ''.join(_cescape_unicode_to_str[ord_(c)] for c in text) if text_is_unicode:
return ''.join(_cescape_byte_to_str[ord_(c)] for c in text) return text.translate(_str_escapes)
else:
return _DecodeUtf8EscapeErrors(text)
else:
if text_is_unicode:
text = text.encode('utf-8')
return ''.join([_byte_escapes[c] for c in text])
_CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])') _CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])')

@ -46,6 +46,8 @@ _QUOTES = frozenset(("'", '"'))
_ANY_FULL_TYPE_NAME = 'google.protobuf.Any' _ANY_FULL_TYPE_NAME = 'google.protobuf.Any'
_DEBUG_STRING_SILENT_MARKER = '\t ' _DEBUG_STRING_SILENT_MARKER = '\t '
_as_utf8_default = True
class Error(Exception): class Error(Exception):
"""Top-level module error for text_format.""" """Top-level module error for text_format."""
@ -91,7 +93,7 @@ class TextWriter(object):
def MessageToString( def MessageToString(
message, message,
as_utf8=False, as_utf8=_as_utf8_default,
as_one_line=False, as_one_line=False,
use_short_repeated_primitives=False, use_short_repeated_primitives=False,
pointy_brackets=False, pointy_brackets=False,
@ -186,7 +188,7 @@ def _IsMapEntry(field):
def PrintMessage(message, def PrintMessage(message,
out, out,
indent=0, indent=0,
as_utf8=False, as_utf8=_as_utf8_default,
as_one_line=False, as_one_line=False,
use_short_repeated_primitives=False, use_short_repeated_primitives=False,
pointy_brackets=False, pointy_brackets=False,
@ -229,7 +231,7 @@ def PrintMessage(message,
the field is a proto message. the field is a proto message.
""" """
printer = _Printer( printer = _Printer(
out=out, indent=indent, as_utf8=as_utf8, out=out, indent=indent, as_utf8=_as_utf8_default,
as_one_line=as_one_line, as_one_line=as_one_line,
use_short_repeated_primitives=use_short_repeated_primitives, use_short_repeated_primitives=use_short_repeated_primitives,
pointy_brackets=pointy_brackets, pointy_brackets=pointy_brackets,
@ -248,7 +250,7 @@ def PrintField(field,
value, value,
out, out,
indent=0, indent=0,
as_utf8=False, as_utf8=_as_utf8_default,
as_one_line=False, as_one_line=False,
use_short_repeated_primitives=False, use_short_repeated_primitives=False,
pointy_brackets=False, pointy_brackets=False,
@ -272,7 +274,7 @@ def PrintFieldValue(field,
value, value,
out, out,
indent=0, indent=0,
as_utf8=False, as_utf8=_as_utf8_default,
as_one_line=False, as_one_line=False,
use_short_repeated_primitives=False, use_short_repeated_primitives=False,
pointy_brackets=False, pointy_brackets=False,
@ -328,7 +330,7 @@ class _Printer(object):
self, self,
out, out,
indent=0, indent=0,
as_utf8=False, as_utf8=_as_utf8_default,
as_one_line=False, as_one_line=False,
use_short_repeated_primitives=False, use_short_repeated_primitives=False,
pointy_brackets=False, pointy_brackets=False,

Loading…
Cancel
Save