From 3e966f1dc98ac2ea9c266e915a387db378527ac0 Mon Sep 17 00:00:00 2001 From: Jie Luo Date: Fri, 1 Sep 2023 14:10:33 -0700 Subject: [PATCH] Add lock to pure python's field decoders Fix data race that may increase flake rate of some tests. PiperOrigin-RevId: 562044867 --- .../protobuf/internal/python_message.py | 8 +- .../protobuf/internal/thread_safe_test.py | 78 +++++++++++++++++++ 2 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 python/google/protobuf/internal/thread_safe_test.py diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 2bd8bc228a..03966bbbf5 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -328,8 +328,8 @@ def _MaybeAddEncoder(cls, field_descriptor): sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( field_descriptor.number, is_repeated, is_packed) - field_descriptor._encoder = field_encoder field_descriptor._sizer = sizer + field_descriptor._encoder = field_encoder def _MaybeAddDecoder(cls, field_descriptor): @@ -338,7 +338,7 @@ def _MaybeAddDecoder(cls, field_descriptor): is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED is_map_entry = _IsMapField(field_descriptor) - field_descriptor._decoders = {} + helper_decoders = {} def AddDecoder(is_packed): decode_type = field_descriptor.type @@ -372,7 +372,7 @@ def _MaybeAddDecoder(cls, field_descriptor): field_descriptor, field_descriptor._default_constructor, not field_descriptor.has_presence) - field_descriptor._decoders[is_packed] = field_decoder + helper_decoders[is_packed] = field_decoder AddDecoder(False) @@ -381,6 +381,8 @@ def _MaybeAddDecoder(cls, field_descriptor): # packed values regardless of the field's options. AddDecoder(True) + field_descriptor._decoders = helper_decoders + def _AddClassAttributesForNestedExtensions(descriptor, dictionary): extensions = descriptor.extensions_by_name diff --git a/python/google/protobuf/internal/thread_safe_test.py b/python/google/protobuf/internal/thread_safe_test.py new file mode 100644 index 0000000000..2edfdca9b9 --- /dev/null +++ b/python/google/protobuf/internal/thread_safe_test.py @@ -0,0 +1,78 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for thread safe""" + +import threading +import time +import unittest + +from google.protobuf import unittest_pb2 + + +class ThreadSafeTest(unittest.TestCase): + + def setUp(self): + self.success = 0 + + def testFieldDecodersDataRace(self): + msg = unittest_pb2.TestAllTypes(optional_int32=1) + serialized_data = msg.SerializeToString() + lock = threading.Lock() + + def ParseMessage(): + parsed_msg = unittest_pb2.TestAllTypes() + time.sleep(0.005) + parsed_msg.ParseFromString(serialized_data) + with lock: + if msg == parsed_msg: + self.success += 1 + + field_des = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[ + 'optional_int32' + ] + count = 5000 + for x in range(0, count): + # delete the _decoders because only the first time parse the field + # may cause data race. + if hasattr(field_des, '_decoders'): + delattr(field_des, '_decoders') + thread1 = threading.Thread(target=ParseMessage) + thread2 = threading.Thread(target=ParseMessage) + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + self.assertEqual(count * 2, self.success) + + +if __name__ == '__main__': + unittest.main()