Add lock to pure python's field decoders

Fix data race that may increase flake rate of some tests.

PiperOrigin-RevId: 562044867
pull/13828/head
Jie Luo 1 year ago committed by Copybara-Service
parent e052928c94
commit 3e966f1dc9
  1. 8
      python/google/protobuf/internal/python_message.py
  2. 78
      python/google/protobuf/internal/thread_safe_test.py

@ -328,8 +328,8 @@ def _MaybeAddEncoder(cls, field_descriptor):
sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
field_descriptor.number, is_repeated, is_packed)
field_descriptor._encoder = field_encoder
field_descriptor._sizer = sizer
field_descriptor._encoder = field_encoder
def _MaybeAddDecoder(cls, field_descriptor):
@ -338,7 +338,7 @@ def _MaybeAddDecoder(cls, field_descriptor):
is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
is_map_entry = _IsMapField(field_descriptor)
field_descriptor._decoders = {}
helper_decoders = {}
def AddDecoder(is_packed):
decode_type = field_descriptor.type
@ -372,7 +372,7 @@ def _MaybeAddDecoder(cls, field_descriptor):
field_descriptor, field_descriptor._default_constructor,
not field_descriptor.has_presence)
field_descriptor._decoders[is_packed] = field_decoder
helper_decoders[is_packed] = field_decoder
AddDecoder(False)
@ -381,6 +381,8 @@ def _MaybeAddDecoder(cls, field_descriptor):
# packed values regardless of the field's options.
AddDecoder(True)
field_descriptor._decoders = helper_decoders
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
extensions = descriptor.extensions_by_name

@ -0,0 +1,78 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Unittest for thread safe"""
import threading
import time
import unittest
from google.protobuf import unittest_pb2
class ThreadSafeTest(unittest.TestCase):
def setUp(self):
self.success = 0
def testFieldDecodersDataRace(self):
msg = unittest_pb2.TestAllTypes(optional_int32=1)
serialized_data = msg.SerializeToString()
lock = threading.Lock()
def ParseMessage():
parsed_msg = unittest_pb2.TestAllTypes()
time.sleep(0.005)
parsed_msg.ParseFromString(serialized_data)
with lock:
if msg == parsed_msg:
self.success += 1
field_des = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
'optional_int32'
]
count = 5000
for x in range(0, count):
# delete the _decoders because only the first time parse the field
# may cause data race.
if hasattr(field_des, '_decoders'):
delattr(field_des, '_decoders')
thread1 = threading.Thread(target=ParseMessage)
thread2 = threading.Thread(target=ParseMessage)
thread1.start()
thread2.start()
thread1.join()
thread2.join()
self.assertEqual(count * 2, self.success)
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save