#include "google/protobuf/map.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/log/absl_log.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"
#include "rust/cpp_kernel/strings.h"

namespace google {
namespace protobuf {
namespace rust {
namespace {

// LINT.IfChange(map_ffi)
// Discriminator for the payload stored in a MapValue.
enum class MapValueTag : uint8_t {
  kBool,
  kU32,
  kU64,
  kString,
  kMessage,
};

// Tagged union used to pass a map value (or a "prototype" that only describes
// the value type) through the extern "C" functions of this file. Only the
// union member selected by `tag` is meaningful.
struct MapValue {
  MapValueTag tag;
  union {
    bool b;
    uint32_t u32;
    uint64_t u64;
    std::string* s;
    google::protobuf::MessageLite* message;
  };
};
// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_ffi)

// Maps the key type used at the FFI boundary to the key type stored in the
// map node. The two are identical except for PtrAndLen, which is stored as a
// std::string.
template <typename T>
struct FromViewType {
  using type = T;
};

template <>
struct FromViewType<PtrAndLen> {
  using type = std::string;
};

// The typed map base that corresponds to the FFI key type `Key`.
template <typename Key>
using KeyMap = internal::KeyMapBase<
    internal::KeyForBase<typename FromViewType<Key>::type>>;

// Writes the in-node size and alignment of the value described by `value.tag`
// to `*size` and `*alignment`. For messages the query is delegated to
// RustMapHelper using `value.message`. On an unexpected tag the outputs are
// left untouched and ABSL_DLOG(FATAL) is emitted.
void GetSizeAndAlignment(MapValue value, uint16_t* size, uint8_t* alignment) {
  switch (value.tag) {
    case MapValueTag::kBool:
      *size = sizeof(bool);
      *alignment = alignof(bool);
      break;
    case MapValueTag::kU32:
      *size = sizeof(uint32_t);
      *alignment = alignof(uint32_t);
      break;
    case MapValueTag::kU64:
      *size = sizeof(uint64_t);
      *alignment = alignof(uint64_t);
      break;
    case MapValueTag::kString:
      *size = sizeof(std::string);
      *alignment = alignof(std::string);
      break;
    case MapValueTag::kMessage:
      internal::RustMapHelper::GetSizeAndAlignment(value.message, size,
                                                   alignment);
      break;
    default:
      ABSL_DLOG(FATAL) << "Unexpected value of MapValue";
  }
}

// Computes the node layout (total size and value offset) for a map whose key
// occupies `key_size` bytes and whose value type is described by `value`.
internal::MapNodeSizeInfoT GetSizeInfo(size_t key_size, MapValue value) {
  // Each map node consists of a NodeBase followed by a std::pair<Key, Value>.
  // We need to compute the offset of the value and the total size of the node.
  size_t node_and_key_size = sizeof(internal::NodeBase) + key_size;

  // Start from benign defaults: GetSizeAndAlignment leaves its outputs
  // untouched for an unexpected tag, and the rounding below divides by the
  // alignment.
  uint16_t value_size = 0;
  uint8_t value_alignment = 1;
  GetSizeAndAlignment(value, &value_size, &value_alignment);

  // Round node_and_key_size up to the nearest multiple of value_alignment.
  uint16_t offset =
      (((node_and_key_size - 1) / value_alignment) + 1) * value_alignment;

  size_t overall_alignment = std::max(alignof(internal::NodeBase),
                                      static_cast<size_t>(value_alignment));
  // Round up size to nearest multiple of overall_alignment.
  size_t overall_size =
      (((offset + value_size - 1) / overall_alignment) + 1) *
      overall_alignment;

  return internal::RustMapHelper::MakeSizeInfo(overall_size, offset);
}

// Runs the destructors of the non-trivial parts of `node` and then releases
// the node's storage through RustMapHelper::DeallocNode.
//
// The key is destroyed when it is stored as a std::string (Key == PtrAndLen).
// The value is destroyed when `value_tag` says it is a std::string or a
// message; Insert() placement-constructs both kinds, so both need an explicit
// destructor call before the storage is released.
template <typename Key>
void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node,
                    internal::MapNodeSizeInfoT size_info,
                    MapValueTag value_tag) {
  if constexpr (std::is_same<Key, PtrAndLen>::value) {
    static_cast<std::string*>(node->GetVoidKey())->~basic_string();
  }
  switch (value_tag) {
    case MapValueTag::kString:
      static_cast<std::string*>(node->GetVoidValue(size_info))
          ->~basic_string();
      break;
    case MapValueTag::kMessage:
      internal::RustMapHelper::DestroyMessage(
          static_cast<MessageLite*>(node->GetVoidValue(size_info)));
      break;
    default:
      // Scalar values are trivially destructible.
      break;
  }
  internal::RustMapHelper::DeallocNode(m, node, size_info);
}

// Constructs a message at `raw_ptr` via RustMapHelper::PlacementNew(msg, ...)
// and transfers the contents of `msg` into it. Takes ownership of `msg` and
// deletes it before returning.
void InitializeMessageValue(void* raw_ptr, MessageLite* msg) {
  MessageLite* new_msg = internal::RustMapHelper::PlacementNew(msg, raw_ptr);
  auto* full_msg = DynamicCastMessage<Message>(new_msg);

  // If we are working with a full (non-lite) proto, we reflectively swap the
  // value into place. Otherwise, we have to perform a copy.
  if (full_msg != nullptr) {
    full_msg->GetReflection()->Swap(full_msg, DynamicCastMessage<Message>(msg));
  } else {
    new_msg->CheckTypeAndMergeFrom(*msg);
  }
  delete msg;
}

// Inserts `key` -> `value` into `m`, replacing any existing entry.
//
// For kString and kMessage the pointee of `value` is consumed: the string is
// moved into the node and deleted, the message is handed to
// InitializeMessageValue (which deletes it).
//
// Returns true if InsertOrReplaceNode reported no displaced node (nullptr),
// false if it returned a node, which is then destroyed here.
template <typename Key>
bool Insert(internal::UntypedMapBase* m, Key key, MapValue value) {
  internal::MapNodeSizeInfoT size_info =
      GetSizeInfo(sizeof(typename FromViewType<Key>::type), value);
  internal::NodeBase* node = internal::RustMapHelper::AllocNode(m, size_info);

  if constexpr (std::is_same<Key, PtrAndLen>::value) {
    new (node->GetVoidKey()) std::string(key.ptr, key.len);
  } else {
    *static_cast<Key*>(node->GetVoidKey()) = key;
  }

  void* value_ptr = node->GetVoidValue(size_info);
  switch (value.tag) {
    case MapValueTag::kBool:
      *static_cast<bool*>(value_ptr) = value.b;
      break;
    case MapValueTag::kU32:
      *static_cast<uint32_t*>(value_ptr) = value.u32;
      break;
    case MapValueTag::kU64:
      *static_cast<uint64_t*>(value_ptr) = value.u64;
      break;
    case MapValueTag::kString:
      new (value_ptr) std::string(std::move(*value.s));
      delete value.s;
      break;
    case MapValueTag::kMessage:
      InitializeMessageValue(value_ptr, value.message);
      break;
    default:
      ABSL_DLOG(FATAL) << "Unexpected value of MapValue";
  }

  node = internal::RustMapHelper::InsertOrReplaceNode(
      static_cast<KeyMap<Key>*>(m), node);
  if (node == nullptr) {
    return true;
  }
  // A node came back from the map (presumably the entry that was replaced);
  // destroy its key/value and free it.
  DestroyMapNode<Key>(m, node, size_info, value.tag);
  return false;
}

// Looks up a non-string key, converting it to the map's internal key type.
template <typename Map, typename Key,
          typename = typename std::enable_if<
              !std::is_same<Key, google::protobuf::rust::PtrAndLen>::value>::type>
internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, Key key) {
  return internal::RustMapHelper::FindHelper(
      m, static_cast<internal::KeyForBase<Key>>(key));
}

// Looks up a string key given as a pointer/length pair.
template <typename Map>
internal::RustMapHelper::NodeAndBucket FindHelper(Map* m,
                                                  google::protobuf::rust::PtrAndLen key) {
  return internal::RustMapHelper::FindHelper(
      m, absl::string_view(key.ptr, key.len));
}

// Fills `output` from the value stored at `data`, interpreted according to
// `tag`. Scalars are copied; for strings and messages `output` receives a
// pointer to the object at `data` (no copy is made).
void PopulateMapValue(MapValueTag tag, void* data, MapValue& output) {
  output.tag = tag;
  switch (tag) {
    case MapValueTag::kBool:
      output.b = *static_cast<const bool*>(data);
      break;
    case MapValueTag::kU32:
      output.u32 = *static_cast<const uint32_t*>(data);
      break;
    case MapValueTag::kU64:
      output.u64 = *static_cast<const uint64_t*>(data);
      break;
    case MapValueTag::kString:
      output.s = static_cast<std::string*>(data);
      break;
    case MapValueTag::kMessage:
      output.message = static_cast<MessageLite*>(data);
      break;
    default:
      ABSL_DLOG(FATAL) << "Unexpected MapValueTag";
  }
}

// Looks up `key` in `m`. `prototype` describes the map's value type and is
// used to compute the node layout. On a hit, fills `*value` and returns true;
// otherwise returns false and leaves `*value` untouched.
template <typename Key>
bool Get(internal::UntypedMapBase* m, MapValue prototype, Key key,
         MapValue* value) {
  internal::MapNodeSizeInfoT size_info =
      GetSizeInfo(sizeof(typename FromViewType<Key>::type), prototype);
  auto* map_base = static_cast<KeyMap<Key>*>(m);
  internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key);
  if (result.node == nullptr) {
    return false;
  }
  PopulateMapValue(prototype.tag, result.node->GetVoidValue(size_info),
                   *value);
  return true;
}

// Removes `key` from `m`. `prototype` describes the map's value type.
// Returns true if an entry was found (and destroyed), false otherwise.
template <typename Key>
bool Remove(internal::UntypedMapBase* m, MapValue prototype, Key key) {
  internal::MapNodeSizeInfoT size_info =
      GetSizeInfo(sizeof(typename FromViewType<Key>::type), prototype);
  auto* map_base = static_cast<KeyMap<Key>*>(m);
  internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key);
  if (result.node == nullptr) {
    return false;
  }
  // EraseNoDestroy only unlinks the node; destruction and deallocation are
  // done by DestroyMapNode.
  internal::RustMapHelper::EraseNoDestroy(map_base, result.bucket,
                                          result.node);
  DestroyMapNode<Key>(m, result.node, size_info, prototype.tag);
  return true;
}

// Reads the key and value of the node `iter` currently points at into `*key`
// and `*value`. A string key is returned as a PtrAndLen that points into the
// node's std::string. `iter->node_` is dereferenced without a null check.
template <typename Key>
void IterGet(const internal::UntypedMapIterator* iter, MapValue prototype,
             Key* key, MapValue* value) {
  internal::MapNodeSizeInfoT size_info =
      GetSizeInfo(sizeof(typename FromViewType<Key>::type), prototype);
  internal::NodeBase* node = iter->node_;
  if constexpr (std::is_same<Key, PtrAndLen>::value) {
    const std::string* s = static_cast<const std::string*>(node->GetVoidKey());
    *key = PtrAndLen{s->data(), s->size()};
  } else {
    *key = *static_cast<const Key*>(node->GetVoidKey());
  }
  PopulateMapValue(prototype.tag, node->GetVoidValue(size_info), *value);
}

// Returns the size of the key in the map entry, given the key used for FFI.
// The map entry key and FFI key are always the same, except in the case of
// string and bytes.
template size_t KeySize() { if constexpr (std::is_same::value) { return sizeof(std::string); } else { return sizeof(Key); } } template void ClearMap(internal::UntypedMapBase* m, bool reset_table, MapValue prototype) { internal::MapNodeSizeInfoT size_info = GetSizeInfo(KeySize(), prototype); if (internal::RustMapHelper::IsGlobalEmptyTable(m)) return; uint8_t bits = 0; if constexpr (std::is_same::value) { bits |= internal::RustMapHelper::kKeyIsString; } if (prototype.tag == MapValueTag::kString) { bits |= internal::RustMapHelper::kValueIsString; } else if (prototype.tag == MapValueTag::kMessage) { bits |= internal::RustMapHelper::kValueIsProto; } internal::RustMapHelper::ClearTable( m, internal::RustMapHelper::ClearInput{size_info, bits, reset_table, /* destroy_node = */ nullptr}); } } // namespace } // namespace rust } // namespace protobuf } // namespace google extern "C" { void proto2_rust_thunk_UntypedMapIterator_increment( google::protobuf::internal::UntypedMapIterator* iter) { iter->PlusPlus(); } google::protobuf::internal::UntypedMapBase* proto2_rust_map_new() { return new google::protobuf::internal::UntypedMapBase(/* arena = */ nullptr); } size_t proto2_rust_map_size(google::protobuf::internal::UntypedMapBase* m) { return m->size(); } google::protobuf::internal::UntypedMapIterator proto2_rust_map_iter( google::protobuf::internal::UntypedMapBase* m) { return m->begin(); } #define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ void proto2_rust_map_free_##suffix(google::protobuf::internal::UntypedMapBase* m, \ google::protobuf::rust::MapValue prototype) { \ google::protobuf::rust::ClearMap(m, /* reset_table = */ false, prototype); \ delete m; \ } \ void proto2_rust_map_clear_##suffix(google::protobuf::internal::UntypedMapBase* m, \ google::protobuf::rust::MapValue prototype) { \ google::protobuf::rust::ClearMap(m, /* reset_table = */ true, prototype); \ } \ bool proto2_rust_map_insert_##suffix(google::protobuf::internal::UntypedMapBase* m, \ cpp_type 
key, \ google::protobuf::rust::MapValue value) { \ return google::protobuf::rust::Insert(m, key, value); \ } \ \ bool proto2_rust_map_get_##suffix( \ google::protobuf::internal::UntypedMapBase* m, google::protobuf::rust::MapValue prototype, \ cpp_type key, google::protobuf::rust::MapValue* value) { \ return google::protobuf::rust::Get(m, prototype, key, value); \ } \ \ bool proto2_rust_map_remove_##suffix(google::protobuf::internal::UntypedMapBase* m, \ google::protobuf::rust::MapValue prototype, \ cpp_type key) { \ return google::protobuf::rust::Remove(m, prototype, key); \ } \ \ void proto2_rust_map_iter_get_##suffix( \ const google::protobuf::internal::UntypedMapIterator* iter, \ google::protobuf::rust::MapValue prototype, cpp_type* key, \ google::protobuf::rust::MapValue* value) { \ return google::protobuf::rust::IterGet(iter, prototype, key, value); \ } DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int32_t, i32) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(uint32_t, u32) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int64_t, i64) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(uint64_t, u64) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(bool, bool) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(google::protobuf::rust::PtrAndLen, ProtoString) } // extern "C"