From 57772cdae7a97495efdbf2f5438801890cb0f93b Mon Sep 17 00:00:00 2001 From: Kyle Montemayor Date: Tue, 16 Apr 2024 10:39:47 -0700 Subject: [PATCH] Add __or__ to enum_type_wrapper so they can be used in type unions PiperOrigin-RevId: 625381202 --- .../protobuf/internal/enum_type_wrapper.py | 4 ++ .../internal/enum_type_wrapper_test.py | 39 +++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 python/google/protobuf/internal/enum_type_wrapper_test.py diff --git a/python/google/protobuf/internal/enum_type_wrapper.py b/python/google/protobuf/internal/enum_type_wrapper.py index cc65fc0bc8..da7d559a35 100644 --- a/python/google/protobuf/internal/enum_type_wrapper.py +++ b/python/google/protobuf/internal/enum_type_wrapper.py @@ -99,3 +99,7 @@ class EnumTypeWrapper(object): pass # fall out to break exception chaining raise AttributeError('Enum {} has no value defined for name {!r}'.format( self._enum_type.name, name)) + + def __or__(self, other): + """Returns the union type of self and other.""" + return type(self) | other diff --git a/python/google/protobuf/internal/enum_type_wrapper_test.py b/python/google/protobuf/internal/enum_type_wrapper_test.py new file mode 100644 index 0000000000..623408aabd --- /dev/null +++ b/python/google/protobuf/internal/enum_type_wrapper_test.py @@ -0,0 +1,39 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd +"""Tests for EnumTypeWrapper.""" + +__author__ = "kmonte@google.com (Kyle Montemayor)" + +import types +import unittest + +from google.protobuf.internal import enum_type_wrapper + +from google.protobuf import unittest_pb2 + + +class EnumTypeWrapperTest(unittest.TestCase): + + def test_type_union(self): + enum_type = enum_type_wrapper.EnumTypeWrapper( + unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR + ) + union_type = enum_type | int + self.assertIsInstance(union_type, types.UnionType) + + def get_union() -> union_type: + return enum_type + + union = get_union() + self.assertIsInstance(union, enum_type_wrapper.EnumTypeWrapper) + self.assertEqual( + union.DESCRIPTOR, unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR + ) + + +if __name__ == "__main__": + unittest.main()