From cac05fe0d1fe26c54a2444c8d954f72c37387506 Mon Sep 17 00:00:00 2001 From: Hong Shin Date: Thu, 19 Dec 2024 06:47:45 -0800 Subject: [PATCH] hpb: Introduce UPB_EXT_PRIMITIVE to flesh out all remaining scalars for extensions In this CL, we add the macro UPB_EXT_PRIMITIVE. The template specializations are practically identical sans the CppType and UpbFunc called, so we now consolidate via this macro. Added support for uint32/64, float/double, and bool. Getting and setting exts of ^ in hpb should all work, and fetch the proper default value as well (if provided in the .proto). PiperOrigin-RevId: 707897721 --- hpb/extension.h | 53 +++++++------- hpb_generator/gen_utils.cc | 10 +++ hpb_generator/tests/extension_test.cc | 88 +++++++++++++++++++++++- hpb_generator/tests/test_extension.proto | 9 ++- 4 files changed, 129 insertions(+), 31 deletions(-) diff --git a/hpb/extension.h b/hpb/extension.h index b55888c4c1..8631e4e1d0 100644 --- a/hpb/extension.h +++ b/hpb/extension.h @@ -70,32 +70,31 @@ struct UpbExtensionTrait> { } }; -template <> -struct UpbExtensionTrait { - using DefaultType = int32_t; - using ReturnType = int32_t; - static constexpr auto kSetter = upb_Message_SetExtensionInt32; - - template - static constexpr ReturnType Get(Msg message, const Id& id) { - auto default_val = hpb::internal::PrivateAccess::GetDefaultValue(id); - return upb_Message_GetExtensionInt32(hpb::interop::upb::GetMessage(message), - id.mini_table_ext(), default_val); - } -}; - -template <> -struct UpbExtensionTrait { - using DefaultType = int64_t; - using ReturnType = int64_t; - static constexpr auto kSetter = upb_Message_SetExtensionInt64; - template - static constexpr ReturnType Get(Msg message, const Id& id) { - auto default_val = hpb::internal::PrivateAccess::GetDefaultValue(id); - return upb_Message_GetExtensionInt64(hpb::interop::upb::GetMessage(message), - id.mini_table_ext(), default_val); - } -}; +#define UPB_EXT_PRIMITIVE(CppType, UpbFunc) \ + template <> \ + struct UpbExtensionTrait { \ + using DefaultType = CppType; \ + using ReturnType = CppType; \ + static constexpr auto kSetter = upb_Message_SetExtension##UpbFunc; \ + \ + template \ + static constexpr ReturnType Get(Msg message, const Id& id) { \ + auto default_val = hpb::internal::PrivateAccess::GetDefaultValue(id); \ + return upb_Message_GetExtension##UpbFunc( \ + hpb::interop::upb::GetMessage(message), id.mini_table_ext(), \ + default_val); \ + } \ + }; + +UPB_EXT_PRIMITIVE(bool, Bool); +UPB_EXT_PRIMITIVE(int32_t, Int32); +UPB_EXT_PRIMITIVE(int64_t, Int64); +UPB_EXT_PRIMITIVE(uint32_t, UInt32); +UPB_EXT_PRIMITIVE(uint64_t, UInt64); +UPB_EXT_PRIMITIVE(float, Float); +UPB_EXT_PRIMITIVE(double, Double); + +#undef UPB_EXT_PRIMITIVE // TODO: b/375460289 - flesh out non-promotional msg support that does // not return an error if missing but the default msg @@ -245,7 +244,7 @@ absl::Status SetExtension( Ptr message, const ::hpb::internal::ExtensionIdentifier& id, const Extension& value) { - if constexpr (std::is_integral_v) { + if constexpr (std::is_arithmetic_v) { bool res = hpb::internal::UpbExtensionTrait::kSetter( hpb::interop::upb::GetMessage(message), id.mini_table_ext(), value, hpb::interop::upb::GetArena(message)); diff --git a/hpb_generator/gen_utils.cc b/hpb_generator/gen_utils.cc index 8b035da7e0..46341638eb 100644 --- a/hpb_generator/gen_utils.cc +++ b/hpb_generator/gen_utils.cc @@ -139,6 +139,16 @@ std::string DefaultValue(const FieldDescriptor* field) { return absl::StrCat(field->default_value_int32()); case FieldDescriptor::CPPTYPE_INT64: return absl::StrCat(field->default_value_int64()); + case FieldDescriptor::CPPTYPE_UINT32: + return absl::StrCat(field->default_value_uint32()); + case FieldDescriptor::CPPTYPE_UINT64: + return absl::StrCat(field->default_value_uint64()); + case FieldDescriptor::CPPTYPE_FLOAT: + return absl::StrCat(field->default_value_float()); + case FieldDescriptor::CPPTYPE_DOUBLE: + return absl::StrCat(field->default_value_double()); + case FieldDescriptor::CPPTYPE_BOOL: + return field->default_value_bool() ? "true" : "false"; case FieldDescriptor::CPPTYPE_MESSAGE: return "::std::false_type()"; default: diff --git a/hpb_generator/tests/extension_test.cc b/hpb_generator/tests/extension_test.cc index 3b6591fb6c..aa9e2b5619 100644 --- a/hpb_generator/tests/extension_test.cc +++ b/hpb_generator/tests/extension_test.cc @@ -7,6 +7,9 @@ #include "google/protobuf/hpb/extension.h" +#include +#include + #include #include #include "google/protobuf/compiler/hpb/tests/child_model.upb.proto.h" @@ -23,11 +26,16 @@ using ::hpb_unittest::protos::other_ext; using ::hpb_unittest::protos::TestModel; using ::hpb_unittest::protos::theme; using ::hpb_unittest::protos::ThemeExtension; +using ::hpb_unittest::someotherpackage::protos::bool_ext; +using ::hpb_unittest::someotherpackage::protos::double_ext; +using ::hpb_unittest::someotherpackage::protos::float_ext; using ::hpb_unittest::someotherpackage::protos::int32_ext; using ::hpb_unittest::someotherpackage::protos::int64_ext; using ::hpb_unittest::someotherpackage::protos::repeated_int32_ext; using ::hpb_unittest::someotherpackage::protos::repeated_int64_ext; using ::hpb_unittest::someotherpackage::protos::repeated_string_ext; +using ::hpb_unittest::someotherpackage::protos::uint32_ext; +using ::hpb_unittest::someotherpackage::protos::uint64_ext; using ::testing::status::IsOkAndHolds; @@ -55,7 +63,7 @@ TEST(CppGeneratedCode, ClearExtensionWithEmptyExtensionPtr) { EXPECT_EQ(false, ::hpb::HasExtension(recursive_child, theme)); } -TEST(CppGeneratedCode, SetExtensionInt32) { +TEST(CppGeneratedCode, GetSetExtensionInt32) { TestModel model; EXPECT_EQ(false, hpb::HasExtension(&model, int32_ext)); int32_t val = 55; @@ -64,7 +72,7 @@ TEST(CppGeneratedCode, SetExtensionInt32) { EXPECT_THAT(hpb::GetExtension(&model, int32_ext), IsOkAndHolds(val)); } -TEST(CppGeneratedCode, SetExtensionInt64) { +TEST(CppGeneratedCode, GetSetExtensionInt64) { TestModel model; EXPECT_EQ(false, hpb::HasExtension(&model, int64_ext)); int64_t val = std::numeric_limits::max() + int64_t{1}; @@ -73,6 +81,50 @@ TEST(CppGeneratedCode, SetExtensionInt64) { EXPECT_THAT(hpb::GetExtension(&model, int64_ext), IsOkAndHolds(val)); } +TEST(CppGeneratedCode, GetSetExtensionUInt32) { + TestModel model; + EXPECT_EQ(false, hpb::HasExtension(&model, uint32_ext)); + uint32_t val = std::numeric_limits::max() + uint32_t{5}; + auto x = hpb::SetExtension(&model, uint32_ext, val); + EXPECT_EQ(true, hpb::HasExtension(&model, uint32_ext)); + EXPECT_THAT(hpb::GetExtension(&model, uint32_ext), IsOkAndHolds(val)); +} + +TEST(CppGeneratedCode, GetSetExtensionUInt64) { + TestModel model; + EXPECT_EQ(false, hpb::HasExtension(&model, uint64_ext)); + uint64_t val = std::numeric_limits::max() + uint64_t{5}; + auto x = hpb::SetExtension(&model, uint64_ext, val); + EXPECT_EQ(true, hpb::HasExtension(&model, uint64_ext)); + EXPECT_THAT(hpb::GetExtension(&model, uint64_ext), IsOkAndHolds(val)); +} + +TEST(CppGeneratedCode, GetSetExtensionFloat) { + TestModel model; + EXPECT_EQ(false, hpb::HasExtension(&model, float_ext)); + float val = 2.78; + auto x = hpb::SetExtension(&model, float_ext, val); + EXPECT_EQ(true, hpb::HasExtension(&model, float_ext)); + EXPECT_THAT(hpb::GetExtension(&model, float_ext), IsOkAndHolds(val)); +} + +TEST(CppGeneratedCode, GetSetExtensionDouble) { + TestModel model; + EXPECT_EQ(false, hpb::HasExtension(&model, double_ext)); + double val = std::numeric_limits::max() + 1.23; + auto x = hpb::SetExtension(&model, double_ext, val); + EXPECT_EQ(true, hpb::HasExtension(&model, double_ext)); + EXPECT_THAT(hpb::GetExtension(&model, double_ext), IsOkAndHolds(val)); +} + +TEST(CppGeneratedCode, GetSetExtensionBool) { + TestModel model; + EXPECT_EQ(false, hpb::HasExtension(&model, bool_ext)); + auto x = hpb::SetExtension(&model, bool_ext, true); + EXPECT_EQ(true, hpb::HasExtension(&model, bool_ext)); + EXPECT_THAT(hpb::GetExtension(&model, bool_ext), IsOkAndHolds(true)); +} + TEST(CppGeneratedCode, SetExtension) { TestModel model; void* prior_message; @@ -235,6 +287,38 @@ TEST(CppGeneratedCode, GetExtensionInt64WithDefault) { EXPECT_EQ(*res, expected); } +TEST(CppGeneratedCode, GetExtensionUInt32WithDefault) { + TestModel model; + auto res = hpb::GetExtension(&model, uint32_ext); + EXPECT_THAT(res, IsOkAndHolds(12)); +} + +TEST(CppGeneratedCode, GetExtensionUInt64WithDefault) { + TestModel model; + auto res = hpb::GetExtension(&model, uint64_ext); + EXPECT_THAT(res, IsOkAndHolds(4294967296)); +} + +TEST(CppGeneratedCode, GetExtensionFloatWithDefault) { + TestModel model; + auto res = hpb::GetExtension(&model, float_ext); + static_assert(std::is_same_v>); + EXPECT_THAT(res, IsOkAndHolds(3.14f)); +} + +TEST(CppGeneratedCode, GetExtensionDoubleWithDefault) { + TestModel model; + auto res = hpb::GetExtension(&model, double_ext); + static_assert(std::is_same_v>); + EXPECT_THAT(res, IsOkAndHolds(340282000000000000000000000000000000001.23)); +} + +TEST(CppGeneratedCode, GetExtensionBoolWithDefault) { + TestModel model; + auto res = hpb::GetExtension(&model, bool_ext); + EXPECT_THAT(res, IsOkAndHolds(true)); +} + TEST(CppGeneratedCode, GetExtensionOnMutableChild) { TestModel model; ThemeExtension extension1; diff --git a/hpb_generator/tests/test_extension.proto b/hpb_generator/tests/test_extension.proto index 30ca3cfdfc..d4b0e8c335 100644 --- a/hpb_generator/tests/test_extension.proto +++ b/hpb_generator/tests/test_extension.proto @@ -20,11 +20,16 @@ extend TestModel { extend TestModel { int32 int32_ext = 13002 [default = 644]; - int64 int64_ext = 13003 [default = 2147483648]; repeated int32 repeated_int32_ext = 13004; repeated int64 repeated_int64_ext = 13005; - repeated string repeated_string_ext = 13006; + + uint32 uint32_ext = 13007 [default = 12]; + uint64 uint64_ext = 13008 [default = 4294967296]; + float float_ext = 13009 [default = 3.14]; + double double_ext = 13010 + [default = 340282000000000000000000000000000000001.23]; + bool bool_ext = 13011 [default = true]; }