Implement GetExtension for numeric (int32 and int64)

At the moment, hpb's public api solely returns Ptr<const Extension>. We'd like to support all non-msg types like int32, int64, bool etc.
These should not return a Ptr<...> but the underlying primitive itself.

We start by adding support for int32 and int64.

PiperOrigin-RevId: 691490444
pull/18868/head
Hong Shin 4 weeks ago committed by Copybara-Service
parent adc8718150
commit 4f6f2dd873
  1. 68
      hpb/extension.h
  2. 5
      hpb/internal/internal.h
  3. 6
      hpb_generator/gen_extensions.cc
  4. 18
      hpb_generator/gen_utils.cc
  5. 2
      hpb_generator/gen_utils.h
  6. 5
      hpb_generator/tests/test_extension.proto
  7. 18
      hpb_generator/tests/test_generated.cc

@ -9,15 +9,19 @@
#define GOOGLE_PROTOBUF_HPB_EXTENSION_H__ #define GOOGLE_PROTOBUF_HPB_EXTENSION_H__
#include <cstdint> #include <cstdint>
#include <type_traits>
#include <vector> #include <vector>
#include "absl/base/attributes.h" #include "absl/base/attributes.h"
#include "absl/status/statusor.h"
#include "google/protobuf/hpb/backend/upb/interop.h" #include "google/protobuf/hpb/backend/upb/interop.h"
#include "google/protobuf/hpb/internal/message_lock.h" #include "google/protobuf/hpb/internal/message_lock.h"
#include "google/protobuf/hpb/internal/template_help.h" #include "google/protobuf/hpb/internal/template_help.h"
#include "google/protobuf/hpb/ptr.h" #include "google/protobuf/hpb/ptr.h"
#include "google/protobuf/hpb/status.h" #include "google/protobuf/hpb/status.h"
#include "upb/mem/arena.hpp" #include "upb/mem/arena.hpp"
#include "upb/message/accessors.h"
#include "upb/message/array.h"
#include "upb/mini_table/extension.h" #include "upb/mini_table/extension.h"
#include "upb/mini_table/extension_registry.h" #include "upb/mini_table/extension_registry.h"
@ -34,6 +38,41 @@ absl::Status SetExtension(upb_Message* message, upb_Arena* message_arena,
const upb_MiniTableExtension* ext, const upb_MiniTableExtension* ext,
const upb_Message* extension); const upb_Message* extension);
/**
* Trait that maps upb extension types to the corresponding
* return value: ubp_MessageValue.
*
* All partial specializations must have:
* - DefaultType: the type of the default value.
* - ReturnType: the type of the return value.
* - kGetter: the corresponding upb_MessageValue upb_Message_GetExtension* func
*/
template <typename T, typename = void>
struct UpbExtensionTrait;
template <>
struct UpbExtensionTrait<int32_t> {
using DefaultType = int32_t;
using ReturnType = int32_t;
static constexpr auto kGetter = upb_Message_GetExtensionInt32;
};
template <>
struct UpbExtensionTrait<int64_t> {
using DefaultType = int64_t;
using ReturnType = int64_t;
static constexpr auto kGetter = upb_Message_GetExtensionInt64;
};
// TODO: b/375460289 - flesh out non-promotional msg support that does
// not return an error if missing but the default msg
template <typename T>
struct UpbExtensionTrait<T> {
using DefaultType = int;
using ReturnType = int;
using DefaultFuncType = void (*)();
};
// ------------------------------------------------------------------- // -------------------------------------------------------------------
// ExtensionIdentifier // ExtensionIdentifier
// This is the type of actual extension objects. E.g. if you have: // This is the type of actual extension objects. E.g. if you have:
@ -54,14 +93,24 @@ class ExtensionIdentifier {
} }
private: private:
constexpr explicit ExtensionIdentifier(const upb_MiniTableExtension* mte) constexpr explicit ExtensionIdentifier(
: mini_table_ext_(mte) {} const upb_MiniTableExtension* mte,
typename UpbExtensionTrait<ExtensionType>::DefaultType val)
: mini_table_ext_(mte), default_val_(val) {}
constexpr uint32_t number() const { constexpr uint32_t number() const {
return upb_MiniTableExtension_Number(mini_table_ext_); return upb_MiniTableExtension_Number(mini_table_ext_);
} }
friend struct PrivateAccess;
const upb_MiniTableExtension* mini_table_ext_; const upb_MiniTableExtension* mini_table_ext_;
typename UpbExtensionTrait<ExtensionType>::ReturnType default_value() const {
return default_val_;
}
typename UpbExtensionTrait<ExtensionType>::DefaultType default_val_;
friend struct PrivateAccess;
}; };
upb_ExtensionRegistry* GetUpbExtensions( upb_ExtensionRegistry* GetUpbExtensions(
@ -216,10 +265,19 @@ absl::StatusOr<Ptr<const Extension>> GetExtension(
template <typename T, typename Extendee, typename Extension, template <typename T, typename Extendee, typename Extension,
typename = hpb::internal::EnableIfHpbClass<T>> typename = hpb::internal::EnableIfHpbClass<T>>
absl::StatusOr<Ptr<const Extension>> GetExtension( decltype(auto) GetExtension(
const T* message, const T* message,
const ::hpb::internal::ExtensionIdentifier<Extendee, Extension>& id) { const hpb::internal::ExtensionIdentifier<Extendee, Extension>& id) {
if constexpr (std::is_integral_v<Extension>) {
auto default_val = hpb::internal::PrivateAccess::GetDefaultValue(id);
absl::StatusOr<Extension> res =
hpb::internal::UpbExtensionTrait<Extension>::kGetter(
hpb::interop::upb::GetMessage(message), id.mini_table_ext(),
default_val);
return res;
} else {
return GetExtension(Ptr(message), id); return GetExtension(Ptr(message), id);
}
} }
template <typename T, typename Extension> template <typename T, typename Extension>

@ -47,6 +47,11 @@ struct PrivateAccess {
static constexpr uint32_t GetExtensionNumber(const ExtensionId& id) { static constexpr uint32_t GetExtensionNumber(const ExtensionId& id) {
return id.number(); return id.number();
} }
template <typename ExtensionId>
static decltype(auto) GetDefaultValue(const ExtensionId& id) {
return id.default_value();
}
}; };
} // namespace hpb::internal } // namespace hpb::internal

@ -8,10 +8,13 @@
#include "google/protobuf/compiler/hpb/gen_extensions.h" #include "google/protobuf/compiler/hpb/gen_extensions.h"
#include <string> #include <string>
#include <vector>
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "google/protobuf/compiler/hpb/context.h" #include "google/protobuf/compiler/hpb/context.h"
#include "google/protobuf/compiler/hpb/gen_utils.h"
#include "google/protobuf/compiler/hpb/names.h" #include "google/protobuf/compiler/hpb/names.h"
#include "google/protobuf/descriptor.h"
#include "upb_generator/c/names.h" #include "upb_generator/c/names.h"
namespace google::protobuf::hpb_generator { namespace google::protobuf::hpb_generator {
@ -73,6 +76,7 @@ void WriteExtensionIdentifier(const protobuf::FieldDescriptor* ext,
{{"containing_type_name", ContainingTypeName(ext)}, {{"containing_type_name", ContainingTypeName(ext)},
{"mini_table_name", mini_table_name}, {"mini_table_name", mini_table_name},
{"ext_name", ext->name()}, {"ext_name", ext->name()},
{"default_value", DefaultValue(ext)},
{"ext_type", CppTypeParameterName(ext)}, {"ext_type", CppTypeParameterName(ext)},
{"class_prefix", class_prefix}}, {"class_prefix", class_prefix}},
R"cc( R"cc(
@ -82,7 +86,7 @@ void WriteExtensionIdentifier(const protobuf::FieldDescriptor* ext,
::hpb::internal::PrivateAccess::InvokeConstructor< ::hpb::internal::PrivateAccess::InvokeConstructor<
::hpb::internal::ExtensionIdentifier<$containing_type_name$, ::hpb::internal::ExtensionIdentifier<$containing_type_name$,
$ext_type$>>( $ext_type$>>(
&$mini_table_name$); &$mini_table_name$, $default_value$);
)cc"); )cc");
} }

@ -11,8 +11,11 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/log/absl_log.h"
#include "absl/strings/ascii.h" #include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "google/protobuf/descriptor.h"
namespace google::protobuf::hpb_generator { namespace google::protobuf::hpb_generator {
@ -127,5 +130,20 @@ std::string ToCamelCase(const absl::string_view input, bool lower_first) {
return result; return result;
} }
std::string DefaultValue(const FieldDescriptor* field) {
switch (field->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32:
return absl::StrCat(field->default_value_int32());
case FieldDescriptor::CPPTYPE_INT64:
return absl::StrCat(field->default_value_int64());
default:
// TODO: b/375460289 - implement rest of scalars + msg
ABSL_LOG(WARNING) << "Unsupported default value type (in-progress): <"
<< field->cpp_type_name()
<< "> For field: " << field->full_name();
return "-1";
}
}
} // namespace protobuf } // namespace protobuf
} // namespace google::hpb_generator } // namespace google::hpb_generator

@ -41,6 +41,8 @@ std::vector<const protobuf::FieldDescriptor*> FieldNumberOrder(
std::string ToCamelCase(absl::string_view input, bool lower_first); std::string ToCamelCase(absl::string_view input, bool lower_first);
std::string DefaultValue(const FieldDescriptor* field);
} // namespace protobuf } // namespace protobuf
} // namespace google::hpb_generator } // namespace google::hpb_generator

@ -17,3 +17,8 @@ import "google/protobuf/compiler/hpb/tests/test_model.proto";
extend TestModel { extend TestModel {
optional ThemeExtension styling = 13001; optional ThemeExtension styling = 13001;
} }
extend TestModel {
optional int32 int32_ext = 13002 [default = 644];
optional int64 int64_ext = 13003 [default = 2147483648];
}

@ -5,6 +5,7 @@
// license that can be found in the LICENSE file or at // license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd // https://developers.google.com/open-source/licenses/bsd
#include <climits>
#include <cstdint> #include <cstdint>
#include <iterator> #include <iterator>
#include <limits> #include <limits>
@ -54,6 +55,8 @@ using ::hpb_unittest::protos::TestModel_Category_NEWS;
using ::hpb_unittest::protos::TestModel_Category_VIDEO; using ::hpb_unittest::protos::TestModel_Category_VIDEO;
using ::hpb_unittest::protos::theme; using ::hpb_unittest::protos::theme;
using ::hpb_unittest::protos::ThemeExtension; using ::hpb_unittest::protos::ThemeExtension;
using ::hpb_unittest::someotherpackage::protos::int32_ext;
using ::hpb_unittest::someotherpackage::protos::int64_ext;
using ::testing::ElementsAre; using ::testing::ElementsAre;
TEST(CppGeneratedCode, Constructor) { TestModel test_model; } TEST(CppGeneratedCode, Constructor) { TestModel test_model; }
@ -878,6 +881,21 @@ TEST(CppGeneratedCode, GetExtension) {
hpb::GetExtension(&model, theme).value()->ext_name()); hpb::GetExtension(&model, theme).value()->ext_name());
} }
TEST(CppGeneratedCode, GetExtensionInt32WithDefault) {
TestModel model;
auto res = hpb::GetExtension(&model, int32_ext);
EXPECT_TRUE(res.ok());
EXPECT_EQ(*res, 644);
}
TEST(CppGeneratedCode, GetExtensionInt64WithDefault) {
TestModel model;
auto res = hpb::GetExtension(&model, int64_ext);
EXPECT_TRUE(res.ok());
int64_t expected = std::numeric_limits<int32_t>::max() + int64_t{1};
EXPECT_EQ(*res, expected);
}
TEST(CppGeneratedCode, GetExtensionOnMutableChild) { TEST(CppGeneratedCode, GetExtensionOnMutableChild) {
TestModel model; TestModel model;
ThemeExtension extension1; ThemeExtension extension1;

Loading…
Cancel
Save