diff --git a/hpb/hpb.cc b/hpb/hpb.cc index 21eada8acb..57163be5bb 100644 --- a/hpb/hpb.cc +++ b/hpb/hpb.cc @@ -7,117 +7,22 @@ #include "google/protobuf/hpb/hpb.h" -#include -#include - #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "google/protobuf/hpb/internal/message_lock.h" +#include "google/protobuf/hpb/status.h" #include "upb/mem/arena.h" #include "upb/message/accessors.h" -#include "upb/message/copy.h" #include "upb/message/message.h" -#include "upb/message/promote.h" -#include "upb/message/value.h" #include "upb/mini_table/extension.h" -#include "upb/mini_table/extension_registry.h" -#include "upb/mini_table/message.h" #include "upb/wire/decode.h" #include "upb/wire/encode.h" namespace hpb { -absl::Status MessageAllocationError(SourceLocation loc) { - return absl::Status(absl::StatusCode::kUnknown, - "Upb message allocation error"); -} - -absl::Status ExtensionNotFoundError(int ext_number, SourceLocation loc) { - return absl::Status(absl::StatusCode::kUnknown, - absl::StrFormat("Extension %d not found", ext_number)); -} - -absl::Status MessageEncodeError(upb_EncodeStatus s, SourceLocation loc) { - return absl::Status(absl::StatusCode::kUnknown, "Encoding error"); -} - -absl::Status MessageDecodeError(upb_DecodeStatus status, SourceLocation loc - -) { - return absl::Status(absl::StatusCode::kUnknown, "Upb message parse error"); -} - namespace internal { -/** - * MessageLock(msg) acquires lock on msg when constructed and releases it when - * destroyed. - */ -class MessageLock { - public: - explicit MessageLock(const upb_Message* msg) : msg_(msg) { - UpbExtensionLocker locker = - upb_extension_locker_global.load(std::memory_order_acquire); - unlocker_ = (locker != nullptr) ? locker(msg) : nullptr; - } - MessageLock(const MessageLock&) = delete; - void operator=(const MessageLock&) = delete; - ~MessageLock() { - if (unlocker_ != nullptr) { - unlocker_(msg_); - } - } - - private: - const upb_Message* msg_; - UpbExtensionUnlocker unlocker_; -}; - -bool HasExtensionOrUnknown(const upb_Message* msg, - const upb_MiniTableExtension* eid) { - MessageLock msg_lock(msg); - if (upb_Message_HasExtension(msg, eid)) return true; - - const int number = upb_MiniTableExtension_Number(eid); - return upb_Message_FindUnknown(msg, number, 0).status == kUpb_FindUnknown_Ok; -} - -bool GetOrPromoteExtension(upb_Message* msg, const upb_MiniTableExtension* eid, - upb_Arena* arena, upb_MessageValue* value) { - MessageLock msg_lock(msg); - upb_GetExtension_Status ext_status = upb_Message_GetOrPromoteExtension( - (upb_Message*)msg, eid, 0, arena, value); - return ext_status == kUpb_GetExtension_Ok; -} - -absl::StatusOr Serialize(const upb_Message* message, - const upb_MiniTable* mini_table, - upb_Arena* arena, int options) { - MessageLock msg_lock(message); - size_t len; - char* ptr; - upb_EncodeStatus status = - upb_Encode(message, mini_table, options, arena, &ptr, &len); - if (status == kUpb_EncodeStatus_Ok) { - return absl::string_view(ptr, len); - } - return MessageEncodeError(status); -} - -void DeepCopy(upb_Message* target, const upb_Message* source, - const upb_MiniTable* mini_table, upb_Arena* arena) { - MessageLock msg_lock(source); - upb_Message_DeepCopy(target, source, mini_table, arena); -} - -upb_Message* DeepClone(const upb_Message* source, - const upb_MiniTable* mini_table, upb_Arena* arena) { - MessageLock msg_lock(source); - return upb_Message_DeepClone(source, mini_table, arena); -} - absl::Status MoveExtension(upb_Message* message, upb_Arena* message_arena, const upb_MiniTableExtension* ext, upb_Message* extension, upb_Arena* extension_arena) { diff --git a/hpb/hpb.h b/hpb/hpb.h index b84b0534ab..2faa6d39eb 100644 --- a/hpb/hpb.h +++ b/hpb/hpb.h @@ -10,7 +10,6 @@ #include #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -19,14 +18,13 @@ #include "google/protobuf/hpb/backend/upb/interop.h" #include "google/protobuf/hpb/extension.h" #include "google/protobuf/hpb/internal/internal.h" +#include "google/protobuf/hpb/internal/message_lock.h" #include "google/protobuf/hpb/internal/template_help.h" #include "google/protobuf/hpb/ptr.h" -#include "upb/base/status.hpp" +#include "google/protobuf/hpb/status.h" #include "upb/mem/arena.hpp" -#include "upb/message/copy.h" #include "upb/mini_table/extension.h" #include "upb/wire/decode.h" -#include "upb/wire/encode.h" #ifdef HPB_BACKEND_UPB #include "google/protobuf/hpb/backend/upb/upb.h" @@ -37,44 +35,8 @@ namespace hpb { class ExtensionRegistry; -// This type exists to work around an absl type that has not yet been -// released. -struct SourceLocation { - static SourceLocation current() { return {}; } - absl::string_view file_name() { return ""; } - int line() { return 0; } -}; - -absl::Status MessageAllocationError( - SourceLocation loc = SourceLocation::current()); - -absl::Status ExtensionNotFoundError( - int extension_number, SourceLocation loc = SourceLocation::current()); - -absl::Status MessageDecodeError(upb_DecodeStatus status, - SourceLocation loc = SourceLocation::current()); - -absl::Status MessageEncodeError(upb_EncodeStatus status, - SourceLocation loc = SourceLocation::current()); - namespace internal { -absl::StatusOr Serialize(const upb_Message* message, - const upb_MiniTable* mini_table, - upb_Arena* arena, int options); - -bool HasExtensionOrUnknown(const upb_Message* msg, - const upb_MiniTableExtension* eid); - -bool GetOrPromoteExtension(upb_Message* msg, const upb_MiniTableExtension* eid, - upb_Arena* arena, upb_MessageValue* value); - -void DeepCopy(upb_Message* target, const upb_Message* source, - const upb_MiniTable* mini_table, upb_Arena* arena); - -upb_Message* DeepClone(const upb_Message* source, - const upb_MiniTable* mini_table, upb_Arena* arena); - absl::Status MoveExtension(upb_Message* message, upb_Arena* message_arena, const upb_MiniTableExtension* ext, upb_Message* extension, upb_Arena* extension_arena); @@ -198,11 +160,10 @@ template > GetExtension( Ptr message, const ::hpb::internal::ExtensionIdentifier& id) { - // TODO: Fix const correctness issues. upb_MessageValue value; const bool ok = ::hpb::internal::GetOrPromoteExtension( - const_cast(hpb::interop::upb::GetMessage(message)), - id.mini_table_ext(), hpb::interop::upb::GetArena(message), &value); + hpb::interop::upb::GetMessage(message), id.mini_table_ext(), + hpb::interop::upb::GetArena(message), &value); if (!ok) { return ExtensionNotFoundError( upb_MiniTableExtension_Number(id.mini_table_ext())); diff --git a/hpb/internal/message_lock.cc b/hpb/internal/message_lock.cc index 62f69ce157..bb5a10d451 100644 --- a/hpb/internal/message_lock.cc +++ b/hpb/internal/message_lock.cc @@ -8,9 +8,94 @@ #include "google/protobuf/hpb/internal/message_lock.h" #include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/hpb/status.h" +#include "upb/mem/arena.h" +#include "upb/message/accessors.h" +#include "upb/message/array.h" +#include "upb/message/copy.h" +#include "upb/message/message.h" +#include "upb/message/promote.h" +#include "upb/mini_table/extension.h" +#include "upb/mini_table/message.h" +#include "upb/wire/encode.h" namespace hpb::internal { std::atomic upb_extension_locker_global; +/** + * MessageLock(msg) acquires lock on msg when constructed and releases it when + * destroyed. + */ +class MessageLock { + public: + explicit MessageLock(const upb_Message* msg) : msg_(msg) { + UpbExtensionLocker locker = + upb_extension_locker_global.load(std::memory_order_acquire); + unlocker_ = (locker != nullptr) ? locker(msg) : nullptr; + } + MessageLock(const MessageLock&) = delete; + void operator=(const MessageLock&) = delete; + ~MessageLock() { + if (unlocker_ != nullptr) { + unlocker_(msg_); + } + } + + private: + const upb_Message* msg_; + UpbExtensionUnlocker unlocker_; +}; + +bool HasExtensionOrUnknown(const upb_Message* msg, + const upb_MiniTableExtension* eid) { + MessageLock msg_lock(msg); + if (upb_Message_HasExtension(msg, eid)) return true; + + const uint32_t number = upb_MiniTableExtension_Number(eid); + return upb_Message_FindUnknown(msg, number, 0).status == kUpb_FindUnknown_Ok; +} + +bool GetOrPromoteExtension(const upb_Message* msg, + const upb_MiniTableExtension* eid, upb_Arena* arena, + upb_MessageValue* value) { + // TODO: Fix const correctness issues. + auto mutable_msg = const_cast(msg); + MessageLock msg_lock(mutable_msg); + upb_GetExtension_Status ext_status = + upb_Message_GetOrPromoteExtension(mutable_msg, eid, 0, arena, value); + return ext_status == kUpb_GetExtension_Ok; +} + +absl::StatusOr Serialize(const upb_Message* message, + const upb_MiniTable* mini_table, + upb_Arena* arena, int options) { + MessageLock msg_lock(message); + size_t len; + char* ptr; + upb_EncodeStatus status = + upb_Encode(message, mini_table, options, arena, &ptr, &len); + if (status == kUpb_EncodeStatus_Ok) { + return absl::string_view(ptr, len); + } + return MessageEncodeError(status); +} + +void DeepCopy(upb_Message* target, const upb_Message* source, + const upb_MiniTable* mini_table, upb_Arena* arena) { + MessageLock msg_lock(source); + upb_Message_DeepCopy(target, source, mini_table, arena); +} + +upb_Message* DeepClone(const upb_Message* source, + const upb_MiniTable* mini_table, upb_Arena* arena) { + MessageLock msg_lock(source); + return upb_Message_DeepClone(source, mini_table, arena); +} + } // namespace hpb::internal diff --git a/hpb/internal/message_lock.h b/hpb/internal/message_lock.h index 430d24bb79..8e86c6bc99 100644 --- a/hpb/internal/message_lock.h +++ b/hpb/internal/message_lock.h @@ -10,6 +10,10 @@ #include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "upb/message/message.h" + namespace hpb::internal { // TODO: Temporary locking api for cross-language @@ -26,6 +30,23 @@ using UpbExtensionLocker = UpbExtensionUnlocker (*)(const void*); // TODO: Expose as function instead of global. extern std::atomic upb_extension_locker_global; +absl::StatusOr Serialize(const upb_Message* message, + const upb_MiniTable* mini_table, + upb_Arena* arena, int options); + +bool HasExtensionOrUnknown(const upb_Message* msg, + const upb_MiniTableExtension* eid); + +bool GetOrPromoteExtension(const upb_Message* msg, + const upb_MiniTableExtension* eid, upb_Arena* arena, + upb_MessageValue* value); + +void DeepCopy(upb_Message* target, const upb_Message* source, + const upb_MiniTable* mini_table, upb_Arena* arena); + +upb_Message* DeepClone(const upb_Message* source, + const upb_MiniTable* mini_table, upb_Arena* arena); + } // namespace hpb::internal #endif // PROTOBUF_HPB_EXTENSION_LOCK_H_ diff --git a/hpb/status.cc b/hpb/status.cc new file mode 100644 index 0000000000..0a1558e992 --- /dev/null +++ b/hpb/status.cc @@ -0,0 +1,37 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2024 Google LLC. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +#include "google/protobuf/hpb/status.h" + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/types/source_location.h" +#include "upb/wire/decode.h" +#include "upb/wire/encode.h" + +namespace hpb { +absl::Status MessageAllocationError(SourceLocation loc) { + return absl::Status(absl::StatusCode::kUnknown, + "Upb message allocation error"); +} + +absl::Status ExtensionNotFoundError(int ext_number, SourceLocation loc) { + return absl::Status(absl::StatusCode::kUnknown, + absl::StrFormat("Extension %d not found", ext_number)); +} + +absl::Status MessageEncodeError(upb_EncodeStatus s, SourceLocation loc) { + return absl::Status(absl::StatusCode::kUnknown, "Encoding error"); +} + +absl::Status MessageDecodeError(upb_DecodeStatus status, SourceLocation loc + +) { + return absl::Status(absl::StatusCode::kUnknown, "Upb message parse error"); +} + +} // namespace hpb diff --git a/hpb/status.h b/hpb/status.h new file mode 100644 index 0000000000..5c373098fe --- /dev/null +++ b/hpb/status.h @@ -0,0 +1,39 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2024 Google LLC. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +#ifndef GOOGLE_PROTOBUF_HPB_STATUS_H__ +#define GOOGLE_PROTOBUF_HPB_STATUS_H__ + +#include "absl/status/status.h" +#include "absl/types/source_location.h" +#include "upb/wire/decode.h" +#include "upb/wire/encode.h" + +namespace hpb { + +// This type exists to work around an absl type that has not yet been +// released. +struct SourceLocation { + static SourceLocation current() { return {}; } + absl::string_view file_name() { return ""; } + int line() { return 0; } +}; + +absl::Status MessageEncodeError(upb_EncodeStatus status, + SourceLocation loc = SourceLocation::current()); + +absl::Status MessageAllocationError( + SourceLocation loc = SourceLocation::current()); + +absl::Status ExtensionNotFoundError( + int extension_number, SourceLocation loc = SourceLocation::current()); + +absl::Status MessageDecodeError(upb_DecodeStatus status, + SourceLocation loc = SourceLocation::current()); +} // namespace hpb + +#endif // GOOGLE_PROTOBUF_HPB_STATUS_H__