// Protocol Buffers - Google's data interchange format // Copyright 2023 Google LLC. All rights reserved. // // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file or at // https://developers.google.com/open-source/licenses/bsd #ifndef PROTOBUF_HPB_HPB_H_ #define PROTOBUF_HPB_HPB_H_ #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "upb/base/status.hpp" #include "upb/mem/arena.hpp" #include "upb/message/copy.h" #include "upb/mini_table/extension.h" #include "upb/wire/decode.h" #include "upb/wire/encode.h" namespace hpb { class ExtensionRegistry; using Arena = ::upb::Arena; template using Proxy = std::conditional_t::value, typename std::remove_const_t::CProxy, typename T::Proxy>; // Provides convenient access to Proxy and CProxy message types. // // Using rebinding and handling of const, Ptr and Ptr // allows copying const with T* const and avoids using non-copyable Proxy types // directly. template class Ptr final { public: Ptr() = delete; // Implicit conversions Ptr(T* m) : p_(m) {} // NOLINT Ptr(const Proxy* p) : p_(*p) {} // NOLINT Ptr(Proxy p) : p_(p) {} // NOLINT Ptr(const Ptr& m) = default; Ptr& operator=(Ptr v) & { Proxy::Rebind(p_, v.p_); return *this; } Proxy operator*() const { return p_; } Proxy* operator->() const { return const_cast*>(std::addressof(p_)); } #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wclass-conversion" #endif template ::value, int> = 0> operator Ptr() const { Proxy p(p_); return Ptr(&p); } #ifdef __clang__ #pragma clang diagnostic pop #endif private: Ptr(upb_Message* msg, upb_Arena* arena) : p_(msg, arena) {} // NOLINT friend class Ptr; friend typename T::Access; Proxy p_; }; // Suppress -Wctad-maybe-unsupported with our manual deduction guide template Ptr(T* m) -> Ptr; // TODO: b/354766950 - Move upb-specific chunks out of hpb header inline absl::string_view UpbStrToStringView(upb_StringView str) { return absl::string_view(str.data, str.size); } // TODO: update bzl and move to upb runtime / protos.cc. inline upb_StringView UpbStrFromStringView(absl::string_view str, upb_Arena* arena) { const size_t str_size = str.size(); char* buffer = static_cast(upb_Arena_Malloc(arena, str_size)); memcpy(buffer, str.data(), str_size); return upb_StringView_FromDataAndSize(buffer, str_size); } // begin:github_only // // This type exists to work around an absl type that has not yet been // // released. // struct SourceLocation { // static SourceLocation current() { return {}; } // absl::string_view file_name() { return ""; } // int line() { return 0; } // }; // end:github_only // begin:google_only using SourceLocation = absl::SourceLocation; // end:google_only absl::Status MessageAllocationError( SourceLocation loc = SourceLocation::current()); absl::Status ExtensionNotFoundError( int extension_number, SourceLocation loc = SourceLocation::current()); absl::Status MessageDecodeError(upb_DecodeStatus status, SourceLocation loc = SourceLocation::current()); absl::Status MessageEncodeError(upb_EncodeStatus status, SourceLocation loc = SourceLocation::current()); namespace internal { template struct RemovePtr; template struct RemovePtr> { using type = T; }; template struct RemovePtr { using type = T; }; template using RemovePtrT = typename RemovePtr::type; template , typename = std::enable_if_t>> using PtrOrRaw = T; template using EnableIfHpbClass = std::enable_if_t< std::is_base_of::value && std::is_base_of::value>; template using EnableIfMutableProto = std::enable_if_t::value>; struct PrivateAccess { template static auto* GetInternalMsg(T&& message) { return message->msg(); } template static auto Proxy(upb_Message* p, upb_Arena* arena) { return typename T::Proxy(p, arena); } template static auto CProxy(const upb_Message* p, upb_Arena* arena) { return typename T::CProxy(p, arena); } template static auto CreateMessage(upb_Arena* arena) { return typename T::Proxy(upb_Message_New(T::minitable(), arena), arena); } template static constexpr uint32_t GetExtensionNumber(const ExtensionId& id) { return id.number(); } }; template auto* GetInternalMsg(T&& message) { return PrivateAccess::GetInternalMsg(std::forward(message)); } template T CreateMessage() { return T(); } template typename T::Proxy CreateMessageProxy(upb_Message* msg, upb_Arena* arena) { return typename T::Proxy(msg, arena); } template typename T::CProxy CreateMessage(const upb_Message* msg, upb_Arena* arena) { return PrivateAccess::CProxy(msg, arena); } class ExtensionMiniTableProvider { public: constexpr explicit ExtensionMiniTableProvider( const upb_MiniTableExtension* mini_table_ext) : mini_table_ext_(mini_table_ext) {} const upb_MiniTableExtension* mini_table_ext() const { return mini_table_ext_; } private: const upb_MiniTableExtension* mini_table_ext_; }; // ------------------------------------------------------------------- // ExtensionIdentifier // This is the type of actual extension objects. E.g. if you have: // extend Foo { // optional MyExtension bar = 1234; // } // then "bar" will be defined in C++ as: // ExtensionIdentifier bar(&namespace_bar_ext); template class ExtensionIdentifier : public ExtensionMiniTableProvider { public: using Extension = ExtensionType; using Extendee = ExtendeeType; constexpr explicit ExtensionIdentifier( const upb_MiniTableExtension* mini_table_ext) : ExtensionMiniTableProvider(mini_table_ext) {} private: constexpr uint32_t number() const { return upb_MiniTableExtension_Number(mini_table_ext()); } friend class PrivateAccess; }; template upb_Arena* GetArena(Ptr message) { return static_cast(message->GetInternalArena()); } template upb_Arena* GetArena(T* message) { return static_cast(message->GetInternalArena()); } template const upb_MiniTable* GetMiniTable(const T*) { return T::minitable(); } template const upb_MiniTable* GetMiniTable(Ptr) { return T::minitable(); } upb_ExtensionRegistry* GetUpbExtensions( const ExtensionRegistry& extension_registry); absl::StatusOr Serialize(const upb_Message* message, const upb_MiniTable* mini_table, upb_Arena* arena, int options); bool HasExtensionOrUnknown(const upb_Message* msg, const upb_MiniTableExtension* eid); bool GetOrPromoteExtension(upb_Message* msg, const upb_MiniTableExtension* eid, upb_Arena* arena, upb_MessageValue* value); void DeepCopy(upb_Message* target, const upb_Message* source, const upb_MiniTable* mini_table, upb_Arena* arena); upb_Message* DeepClone(const upb_Message* source, const upb_MiniTable* mini_table, upb_Arena* arena); absl::Status MoveExtension(upb_Message* message, upb_Arena* message_arena, const upb_MiniTableExtension* ext, upb_Message* extension, upb_Arena* extension_arena); absl::Status SetExtension(upb_Message* message, upb_Arena* message_arena, const upb_MiniTableExtension* ext, const upb_Message* extension); } // namespace internal class ExtensionRegistry { public: ExtensionRegistry( const std::vector& extensions, const upb::Arena& arena) : registry_(upb_ExtensionRegistry_New(arena.ptr())) { if (registry_) { for (const auto& ext_provider : extensions) { const auto* ext = ext_provider->mini_table_ext(); bool success = upb_ExtensionRegistry_AddArray(registry_, &ext, 1); if (!success) { registry_ = nullptr; break; } } } } private: friend upb_ExtensionRegistry* ::hpb::internal::GetUpbExtensions( const ExtensionRegistry& extension_registry); upb_ExtensionRegistry* registry_; }; template > ABSL_MUST_USE_RESULT bool HasExtension( Ptr message, const ::hpb::internal::ExtensionIdentifier& id) { return ::hpb::internal::HasExtensionOrUnknown( ::hpb::internal::GetInternalMsg(message), id.mini_table_ext()); } template > ABSL_MUST_USE_RESULT bool HasExtension( const T* message, const ::hpb::internal::ExtensionIdentifier& id) { return HasExtension(Ptr(message), id); } template , typename = hpb::internal::EnableIfMutableProto> void ClearExtension( Ptr message, const ::hpb::internal::ExtensionIdentifier& id) { static_assert(!std::is_const_v, ""); upb_Message_ClearExtension(hpb::internal::GetInternalMsg(message), id.mini_table_ext()); } template > void ClearExtension( T* message, const ::hpb::internal::ExtensionIdentifier& id) { ClearExtension(Ptr(message), id); } template , typename = hpb::internal::EnableIfMutableProto> absl::Status SetExtension( Ptr message, const ::hpb::internal::ExtensionIdentifier& id, const Extension& value) { static_assert(!std::is_const_v); auto* message_arena = static_cast(message->GetInternalArena()); return ::hpb::internal::SetExtension(hpb::internal::GetInternalMsg(message), message_arena, id.mini_table_ext(), hpb::internal::GetInternalMsg(&value)); } template , typename = hpb::internal::EnableIfMutableProto> absl::Status SetExtension( Ptr message, const ::hpb::internal::ExtensionIdentifier& id, Ptr value) { static_assert(!std::is_const_v); auto* message_arena = static_cast(message->GetInternalArena()); return ::hpb::internal::SetExtension(hpb::internal::GetInternalMsg(message), message_arena, id.mini_table_ext(), hpb::internal::GetInternalMsg(value)); } template , typename = hpb::internal::EnableIfMutableProto> absl::Status SetExtension( Ptr message, const ::hpb::internal::ExtensionIdentifier& id, Extension&& value) { Extension ext = std::move(value); static_assert(!std::is_const_v); auto* message_arena = static_cast(message->GetInternalArena()); auto* extension_arena = static_cast(ext.GetInternalArena()); return ::hpb::internal::MoveExtension(hpb::internal::GetInternalMsg(message), message_arena, id.mini_table_ext(), hpb::internal::GetInternalMsg(&ext), extension_arena); } template > absl::Status SetExtension( T* message, const ::hpb::internal::ExtensionIdentifier& id, const Extension& value) { return ::hpb::SetExtension(Ptr(message), id, value); } template > absl::Status SetExtension( T* message, const ::hpb::internal::ExtensionIdentifier& id, Extension&& value) { return ::hpb::SetExtension(Ptr(message), id, std::forward(value)); } template > absl::Status SetExtension( T* message, const ::hpb::internal::ExtensionIdentifier& id, Ptr value) { return ::hpb::SetExtension(Ptr(message), id, value); } template > absl::StatusOr> GetExtension( Ptr message, const ::hpb::internal::ExtensionIdentifier& id) { // TODO: Fix const correctness issues. upb_MessageValue value; const bool ok = ::hpb::internal::GetOrPromoteExtension( const_cast(::hpb::internal::GetInternalMsg(message)), id.mini_table_ext(), ::hpb::internal::GetArena(message), &value); if (!ok) { return ExtensionNotFoundError( upb_MiniTableExtension_Number(id.mini_table_ext())); } return Ptr(::hpb::internal::CreateMessage( (upb_Message*)value.msg_val, ::hpb::internal::GetArena(message))); } template > absl::StatusOr> GetExtension( const T* message, const ::hpb::internal::ExtensionIdentifier& id) { return GetExtension(Ptr(message), id); } template constexpr uint32_t ExtensionNumber( ::hpb::internal::ExtensionIdentifier id) { return ::hpb::internal::PrivateAccess::GetExtensionNumber(id); } template typename T::Proxy CreateMessage(::hpb::Arena& arena) { return typename T::Proxy(upb_Message_New(T::minitable(), arena.ptr()), arena.ptr()); } template typename T::Proxy CloneMessage(Ptr message, upb_Arena* arena) { return ::hpb::internal::PrivateAccess::Proxy( ::hpb::internal::DeepClone(::hpb::internal::GetInternalMsg(message), T::minitable(), arena), arena); } template void DeepCopy(Ptr source_message, Ptr target_message) { static_assert(!std::is_const_v); ::hpb::internal::DeepCopy( hpb::internal::GetInternalMsg(target_message), hpb::internal::GetInternalMsg(source_message), T::minitable(), static_cast(target_message->GetInternalArena())); } template void DeepCopy(Ptr source_message, T* target_message) { static_assert(!std::is_const_v); DeepCopy(source_message, Ptr(target_message)); } template void DeepCopy(const T* source_message, Ptr target_message) { static_assert(!std::is_const_v); DeepCopy(Ptr(source_message), target_message); } template void DeepCopy(const T* source_message, T* target_message) { static_assert(!std::is_const_v); DeepCopy(Ptr(source_message), Ptr(target_message)); } template void ClearMessage(hpb::internal::PtrOrRaw message) { auto ptr = Ptr(message); auto minitable = hpb::internal::GetMiniTable(ptr); upb_Message_Clear(hpb::internal::GetInternalMsg(ptr), minitable); } template ABSL_MUST_USE_RESULT bool Parse(Ptr message, absl::string_view bytes) { static_assert(!std::is_const_v); upb_Message_Clear(::hpb::internal::GetInternalMsg(message), ::hpb::internal::GetMiniTable(message)); auto* arena = static_cast(message->GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), ::hpb::internal::GetInternalMsg(message), ::hpb::internal::GetMiniTable(message), /* extreg= */ nullptr, /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; } template ABSL_MUST_USE_RESULT bool Parse( Ptr message, absl::string_view bytes, const ::hpb::ExtensionRegistry& extension_registry) { static_assert(!std::is_const_v); upb_Message_Clear(::hpb::internal::GetInternalMsg(message), ::hpb::internal::GetMiniTable(message)); auto* arena = static_cast(message->GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), ::hpb::internal::GetInternalMsg(message), ::hpb::internal::GetMiniTable(message), /* extreg= */ ::hpb::internal::GetUpbExtensions(extension_registry), /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; } template ABSL_MUST_USE_RESULT bool Parse( T* message, absl::string_view bytes, const ::hpb::ExtensionRegistry& extension_registry) { static_assert(!std::is_const_v); return Parse(Ptr(message, bytes, extension_registry)); } template ABSL_MUST_USE_RESULT bool Parse(T* message, absl::string_view bytes) { static_assert(!std::is_const_v); upb_Message_Clear(::hpb::internal::GetInternalMsg(message), ::hpb::internal::GetMiniTable(message)); auto* arena = static_cast(message->GetInternalArena()); return upb_Decode(bytes.data(), bytes.size(), ::hpb::internal::GetInternalMsg(message), ::hpb::internal::GetMiniTable(message), /* extreg= */ nullptr, /* options= */ 0, arena) == kUpb_DecodeStatus_Ok; } template absl::StatusOr Parse(absl::string_view bytes, int options = 0) { T message; auto* arena = static_cast(message.GetInternalArena()); upb_DecodeStatus status = upb_Decode(bytes.data(), bytes.size(), message.msg(), ::hpb::internal::GetMiniTable(&message), /* extreg= */ nullptr, /* options= */ 0, arena); if (status == kUpb_DecodeStatus_Ok) { return message; } return MessageDecodeError(status); } template absl::StatusOr Parse(absl::string_view bytes, const ::hpb::ExtensionRegistry& extension_registry, int options = 0) { T message; auto* arena = static_cast(message.GetInternalArena()); upb_DecodeStatus status = upb_Decode(bytes.data(), bytes.size(), message.msg(), ::hpb::internal::GetMiniTable(&message), ::hpb::internal::GetUpbExtensions(extension_registry), /* options= */ 0, arena); if (status == kUpb_DecodeStatus_Ok) { return message; } return MessageDecodeError(status); } template absl::StatusOr Serialize(const T* message, upb::Arena& arena, int options = 0) { return ::hpb::internal::Serialize(::hpb::internal::GetInternalMsg(message), ::hpb::internal::GetMiniTable(message), arena.ptr(), options); } template absl::StatusOr Serialize(Ptr message, upb::Arena& arena, int options = 0) { return ::hpb::internal::Serialize(::hpb::internal::GetInternalMsg(message), ::hpb::internal::GetMiniTable(message), arena.ptr(), options); } } // namespace hpb #endif // PROTOBUF_HPB_HPB_H_