diff --git a/hpb/extension.h b/hpb/extension.h index 111bbc712d..ae8d23b19a 100644 --- a/hpb/extension.h +++ b/hpb/extension.h @@ -9,15 +9,19 @@ #define GOOGLE_PROTOBUF_HPB_EXTENSION_H__ #include +#include #include #include "absl/base/attributes.h" +#include "absl/status/statusor.h" #include "google/protobuf/hpb/backend/upb/interop.h" #include "google/protobuf/hpb/internal/message_lock.h" #include "google/protobuf/hpb/internal/template_help.h" #include "google/protobuf/hpb/ptr.h" #include "google/protobuf/hpb/status.h" #include "upb/mem/arena.hpp" +#include "upb/message/accessors.h" +#include "upb/message/array.h" #include "upb/mini_table/extension.h" #include "upb/mini_table/extension_registry.h" @@ -34,6 +38,41 @@ absl::Status SetExtension(upb_Message* message, upb_Arena* message_arena, const upb_MiniTableExtension* ext, const upb_Message* extension); +/** + * Trait that maps upb extension types to the corresponding + * return value: ubp_MessageValue. + * + * All partial specializations must have: + * - DefaultType: the type of the default value. + * - ReturnType: the type of the return value. + * - kGetter: the corresponding upb_MessageValue upb_Message_GetExtension* func + */ +template +struct UpbExtensionTrait; + +template <> +struct UpbExtensionTrait { + using DefaultType = int32_t; + using ReturnType = int32_t; + static constexpr auto kGetter = upb_Message_GetExtensionInt32; +}; + +template <> +struct UpbExtensionTrait { + using DefaultType = int64_t; + using ReturnType = int64_t; + static constexpr auto kGetter = upb_Message_GetExtensionInt64; +}; + +// TODO: b/375460289 - flesh out non-promotional msg support that does +// not return an error if missing but the default msg +template +struct UpbExtensionTrait { + using DefaultType = int; + using ReturnType = int; + using DefaultFuncType = void (*)(); +}; + // ------------------------------------------------------------------- // ExtensionIdentifier // This is the type of actual extension objects. E.g. if you have: @@ -54,14 +93,24 @@ class ExtensionIdentifier { } private: - constexpr explicit ExtensionIdentifier(const upb_MiniTableExtension* mte) - : mini_table_ext_(mte) {} + constexpr explicit ExtensionIdentifier( + const upb_MiniTableExtension* mte, + typename UpbExtensionTrait::DefaultType val) + : mini_table_ext_(mte), default_val_(val) {} + constexpr uint32_t number() const { return upb_MiniTableExtension_Number(mini_table_ext_); } - friend struct PrivateAccess; const upb_MiniTableExtension* mini_table_ext_; + + typename UpbExtensionTrait::ReturnType default_value() const { + return default_val_; + } + + typename UpbExtensionTrait::DefaultType default_val_; + + friend struct PrivateAccess; }; upb_ExtensionRegistry* GetUpbExtensions( @@ -216,10 +265,19 @@ absl::StatusOr> GetExtension( template > -absl::StatusOr> GetExtension( +decltype(auto) GetExtension( const T* message, - const ::hpb::internal::ExtensionIdentifier& id) { - return GetExtension(Ptr(message), id); + const hpb::internal::ExtensionIdentifier& id) { + if constexpr (std::is_integral_v) { + auto default_val = hpb::internal::PrivateAccess::GetDefaultValue(id); + absl::StatusOr res = + hpb::internal::UpbExtensionTrait::kGetter( + hpb::interop::upb::GetMessage(message), id.mini_table_ext(), + default_val); + return res; + } else { + return GetExtension(Ptr(message), id); + } } template diff --git a/hpb/internal/internal.h b/hpb/internal/internal.h index beedcf7c78..2634c3cf48 100644 --- a/hpb/internal/internal.h +++ b/hpb/internal/internal.h @@ -47,6 +47,11 @@ struct PrivateAccess { static constexpr uint32_t GetExtensionNumber(const ExtensionId& id) { return id.number(); } + + template + static decltype(auto) GetDefaultValue(const ExtensionId& id) { + return id.default_value(); + } }; } // namespace hpb::internal diff --git a/hpb_generator/gen_extensions.cc b/hpb_generator/gen_extensions.cc index 959da3db3a..cf9650d315 100644 --- a/hpb_generator/gen_extensions.cc +++ b/hpb_generator/gen_extensions.cc @@ -8,10 +8,13 @@ #include "google/protobuf/compiler/hpb/gen_extensions.h" #include +#include #include "absl/strings/str_cat.h" #include "google/protobuf/compiler/hpb/context.h" +#include "google/protobuf/compiler/hpb/gen_utils.h" #include "google/protobuf/compiler/hpb/names.h" +#include "google/protobuf/descriptor.h" #include "upb_generator/c/names.h" namespace google::protobuf::hpb_generator { @@ -73,6 +76,7 @@ void WriteExtensionIdentifier(const protobuf::FieldDescriptor* ext, {{"containing_type_name", ContainingTypeName(ext)}, {"mini_table_name", mini_table_name}, {"ext_name", ext->name()}, + {"default_value", DefaultValue(ext)}, {"ext_type", CppTypeParameterName(ext)}, {"class_prefix", class_prefix}}, R"cc( @@ -82,7 +86,7 @@ void WriteExtensionIdentifier(const protobuf::FieldDescriptor* ext, ::hpb::internal::PrivateAccess::InvokeConstructor< ::hpb::internal::ExtensionIdentifier<$containing_type_name$, $ext_type$>>( - &$mini_table_name$); + &$mini_table_name$, $default_value$); )cc"); } diff --git a/hpb_generator/gen_utils.cc b/hpb_generator/gen_utils.cc index f89afc99e2..d0238d0104 100644 --- a/hpb_generator/gen_utils.cc +++ b/hpb_generator/gen_utils.cc @@ -11,8 +11,11 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "google/protobuf/descriptor.h" namespace google::protobuf::hpb_generator { @@ -127,5 +130,20 @@ std::string ToCamelCase(const absl::string_view input, bool lower_first) { return result; } +std::string DefaultValue(const FieldDescriptor* field) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return absl::StrCat(field->default_value_int32()); + case FieldDescriptor::CPPTYPE_INT64: + return absl::StrCat(field->default_value_int64()); + default: + // TODO: b/375460289 - implement rest of scalars + msg + ABSL_LOG(WARNING) << "Unsupported default value type (in-progress): <" + << field->cpp_type_name() + << "> For field: " << field->full_name(); + return "-1"; + } +} + } // namespace protobuf } // namespace google::hpb_generator diff --git a/hpb_generator/gen_utils.h b/hpb_generator/gen_utils.h index a83500c719..154824776e 100644 --- a/hpb_generator/gen_utils.h +++ b/hpb_generator/gen_utils.h @@ -41,6 +41,8 @@ std::vector FieldNumberOrder( std::string ToCamelCase(absl::string_view input, bool lower_first); +std::string DefaultValue(const FieldDescriptor* field); + } // namespace protobuf } // namespace google::hpb_generator diff --git a/hpb_generator/tests/test_extension.proto b/hpb_generator/tests/test_extension.proto index f7758da7df..e8ff9d5329 100644 --- a/hpb_generator/tests/test_extension.proto +++ b/hpb_generator/tests/test_extension.proto @@ -17,3 +17,8 @@ import "google/protobuf/compiler/hpb/tests/test_model.proto"; extend TestModel { optional ThemeExtension styling = 13001; } + +extend TestModel { + optional int32 int32_ext = 13002 [default = 644]; + optional int64 int64_ext = 13003 [default = 2147483648]; +} diff --git a/hpb_generator/tests/test_generated.cc b/hpb_generator/tests/test_generated.cc index f9fb8132f7..23c0cfe0e0 100644 --- a/hpb_generator/tests/test_generated.cc +++ b/hpb_generator/tests/test_generated.cc @@ -5,6 +5,7 @@ // license that can be found in the LICENSE file or at // https://developers.google.com/open-source/licenses/bsd +#include #include #include #include @@ -54,6 +55,8 @@ using ::hpb_unittest::protos::TestModel_Category_NEWS; using ::hpb_unittest::protos::TestModel_Category_VIDEO; using ::hpb_unittest::protos::theme; using ::hpb_unittest::protos::ThemeExtension; +using ::hpb_unittest::someotherpackage::protos::int32_ext; +using ::hpb_unittest::someotherpackage::protos::int64_ext; using ::testing::ElementsAre; TEST(CppGeneratedCode, Constructor) { TestModel test_model; } @@ -878,6 +881,21 @@ TEST(CppGeneratedCode, GetExtension) { hpb::GetExtension(&model, theme).value()->ext_name()); } +TEST(CppGeneratedCode, GetExtensionInt32WithDefault) { + TestModel model; + auto res = hpb::GetExtension(&model, int32_ext); + EXPECT_TRUE(res.ok()); + EXPECT_EQ(*res, 644); +} + +TEST(CppGeneratedCode, GetExtensionInt64WithDefault) { + TestModel model; + auto res = hpb::GetExtension(&model, int64_ext); + EXPECT_TRUE(res.ok()); + int64_t expected = std::numeric_limits::max() + int64_t{1}; + EXPECT_EQ(*res, expected); +} + TEST(CppGeneratedCode, GetExtensionOnMutableChild) { TestModel model; ThemeExtension extension1;