#include "rust/cpp_kernel/map.h" #include #include #include #include #include #include "absl/log/absl_log.h" #include "google/protobuf/map.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "rust/cpp_kernel/strings.h" namespace google { namespace protobuf { namespace rust { namespace { template struct FromViewType { using type = T; }; template <> struct FromViewType { using type = std::string; }; template using KeyMap = internal::KeyMapBase< internal::KeyForBase::type>>; internal::MapNodeSizeInfoT GetSizeInfo(size_t key_size, const google::protobuf::MessageLite* value) { // Each map node consists of a NodeBase followed by a std::pair. // We need to compute the offset of the value and the total size of the node. size_t node_and_key_size = sizeof(internal::NodeBase) + key_size; uint16_t value_size; uint8_t value_alignment; internal::RustMapHelper::GetSizeAndAlignment(value, &value_size, &value_alignment); // Round node_and_key_size up to the nearest multiple of value_alignment. uint16_t offset = (((node_and_key_size - 1) / value_alignment) + 1) * value_alignment; return internal::RustMapHelper::MakeSizeInfo(offset + value_size, offset); } template internal::MapNodeSizeInfoT GetSizeInfo(const google::protobuf::MessageLite* value) { return GetSizeInfo(sizeof(Key), value); } template void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node, internal::MapNodeSizeInfoT size_info) { if constexpr (std::is_same::value) { static_cast(node->GetVoidKey())->~basic_string(); } internal::RustMapHelper::DestroyMessage( static_cast(node->GetVoidValue(size_info))); internal::RustMapHelper::DeallocNode(m, node, size_info); } template bool Insert(internal::UntypedMapBase* m, Key key, MessageLite* value) { internal::MapNodeSizeInfoT size_info = GetSizeInfo::type>(value); internal::NodeBase* node = internal::RustMapHelper::AllocNode(m, size_info); if constexpr (std::is_same::value) { new (node->GetVoidKey()) std::string(key.ptr, key.len); } else { *static_cast(node->GetVoidKey()) = key; } MessageLite* new_msg = internal::RustMapHelper::PlacementNew( value, node->GetVoidValue(size_info)); auto* full_msg = DynamicCastMessage(new_msg); // If we are working with a full (non-lite) proto, we reflectively swap the // value into place. Otherwise, we have to perform a copy. if (full_msg != nullptr) { full_msg->GetReflection()->Swap(full_msg, DynamicCastMessage(value)); } else { new_msg->CheckTypeAndMergeFrom(*value); } node = internal::RustMapHelper::InsertOrReplaceNode( static_cast*>(m), node); if (node == nullptr) { return true; } DestroyMapNode(m, node, size_info); return false; } template ::value>::type> internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, Key key) { return internal::RustMapHelper::FindHelper( m, static_cast>(key)); } template internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, google::protobuf::rust::PtrAndLen key) { return internal::RustMapHelper::FindHelper( m, absl::string_view(key.ptr, key.len)); } template bool Get(internal::UntypedMapBase* m, const google::protobuf::MessageLite* prototype, Key key, MessageLite** value) { internal::MapNodeSizeInfoT size_info = GetSizeInfo::type>(prototype); auto* map_base = static_cast*>(m); internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); if (result.node == nullptr) { return false; } *value = static_cast(result.node->GetVoidValue(size_info)); return true; } template bool Remove(internal::UntypedMapBase* m, const google::protobuf::MessageLite* prototype, Key key) { internal::MapNodeSizeInfoT size_info = GetSizeInfo::type>(prototype); auto* map_base = static_cast*>(m); internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); if (result.node == nullptr) { return false; } internal::RustMapHelper::EraseNoDestroy(map_base, result.bucket, result.node); DestroyMapNode(m, result.node, size_info); return true; } template void IterGet(const internal::UntypedMapIterator* iter, const google::protobuf::MessageLite* prototype, Key* key, MessageLite** value) { internal::MapNodeSizeInfoT size_info = GetSizeInfo::type>(prototype); internal::NodeBase* node = iter->node_; if constexpr (std::is_same::value) { const std::string* s = static_cast(node->GetVoidKey()); *key = PtrAndLen{s->data(), s->size()}; } else { *key = *static_cast(node->GetVoidKey()); } *value = static_cast(node->GetVoidValue(size_info)); } // LINT.IfChange(map_key_category) enum class MapKeyCategory : uint8_t { kOneByte = 0, kFourBytes = 1, kEightBytes = 2, kStdString = 3, }; // LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_key_category) size_t KeySize(MapKeyCategory category) { switch (category) { case MapKeyCategory::kOneByte: return 1; case MapKeyCategory::kFourBytes: return 4; case MapKeyCategory::kEightBytes: return 8; case MapKeyCategory::kStdString: return sizeof(std::string); default: ABSL_DLOG(FATAL) << "Unexpected value of MapKeyCategory enum"; } } void ClearMap(internal::UntypedMapBase* m, MapKeyCategory category, bool reset_table, const google::protobuf::MessageLite* prototype) { internal::MapNodeSizeInfoT size_info = GetSizeInfo(KeySize(category), prototype); if (internal::RustMapHelper::IsGlobalEmptyTable(m)) return; uint8_t bits = internal::RustMapHelper::kValueIsProto; if (category == MapKeyCategory::kStdString) { bits |= internal::RustMapHelper::kKeyIsString; } internal::RustMapHelper::ClearTable( m, internal::RustMapHelper::ClearInput{size_info, bits, reset_table, /* destroy_node = */ nullptr}); } } // namespace } // namespace rust } // namespace protobuf } // namespace google extern "C" { void proto2_rust_thunk_UntypedMapIterator_increment( google::protobuf::internal::UntypedMapIterator* iter) { iter->PlusPlus(); } google::protobuf::internal::UntypedMapBase* proto2_rust_map_new() { return new google::protobuf::internal::UntypedMapBase(/* arena = */ nullptr); } void proto2_rust_map_free(google::protobuf::internal::UntypedMapBase* m, google::protobuf::rust::MapKeyCategory category, const google::protobuf::MessageLite* prototype) { google::protobuf::rust::ClearMap(m, category, /* reset_table = */ false, prototype); delete m; } void proto2_rust_map_clear(google::protobuf::internal::UntypedMapBase* m, google::protobuf::rust::MapKeyCategory category, const google::protobuf::MessageLite* prototype) { google::protobuf::rust::ClearMap(m, category, /* reset_table = */ true, prototype); } size_t proto2_rust_map_size(google::protobuf::internal::UntypedMapBase* m) { return m->size(); } google::protobuf::internal::UntypedMapIterator proto2_rust_map_iter( google::protobuf::internal::UntypedMapBase* m) { return m->begin(); } #define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ bool proto2_rust_map_insert_##suffix(google::protobuf::internal::UntypedMapBase* m, \ cpp_type key, \ google::protobuf::MessageLite* value) { \ return google::protobuf::rust::Insert(m, key, value); \ } \ \ bool proto2_rust_map_get_##suffix(google::protobuf::internal::UntypedMapBase* m, \ const google::protobuf::MessageLite* prototype, \ cpp_type key, \ google::protobuf::MessageLite** value) { \ return google::protobuf::rust::Get(m, prototype, key, value); \ } \ \ bool proto2_rust_map_remove_##suffix(google::protobuf::internal::UntypedMapBase* m, \ const google::protobuf::MessageLite* prototype, \ cpp_type key) { \ return google::protobuf::rust::Remove(m, prototype, key); \ } \ \ void proto2_rust_map_iter_get_##suffix( \ const google::protobuf::internal::UntypedMapIterator* iter, \ const google::protobuf::MessageLite* prototype, cpp_type* key, \ google::protobuf::MessageLite** value) { \ return google::protobuf::rust::IterGet(iter, prototype, key, value); \ } DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int32_t, i32) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(uint32_t, u32) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int64_t, i64) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(uint64_t, u64) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(bool, bool) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(google::protobuf::rust::PtrAndLen, ProtoString) __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(int32_t, i32, int32_t, int32_t, value, cpp_value); __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(uint32_t, u32, uint32_t, uint32_t, value, cpp_value); __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(float, f32, float, float, value, cpp_value); __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(double, f64, double, double, value, cpp_value); __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(bool, bool, bool, bool, value, cpp_value); __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(uint64_t, u64, uint64_t, uint64_t, value, cpp_value); __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(int64_t, i64, int64_t, int64_t, value, cpp_value); __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE( std::string, ProtoBytes, google::protobuf::rust::PtrAndLen, std::string*, std::move(*value), (google::protobuf::rust::PtrAndLen{cpp_value.data(), cpp_value.size()})); __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE( std::string, ProtoString, google::protobuf::rust::PtrAndLen, std::string*, std::move(*value), (google::protobuf::rust::PtrAndLen{cpp_value.data(), cpp_value.size()})); } // extern "C"