diff --git a/rust/cpp.rs b/rust/cpp.rs index 661b3c5613..460910463c 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -110,11 +110,6 @@ extern "C" { pub fn proto2_rust_Message_merge_from(dst: RawMessage, src: RawMessage) -> bool; } -/// An opaque type matching MapNodeSizeInfoT from C++. -#[doc(hidden)] -#[repr(transparent)] -pub struct MapNodeSizeInfo(pub i32); - impl Drop for InnerProtoString { fn drop(&mut self) { // SAFETY: `self.owned_ptr` points to a valid std::string object. @@ -767,36 +762,45 @@ impl UntypedMapIterator { } } +// This enum is used to pass some information about the key type of a map to +// C++. The main purpose is to indicate the size of the key so that we can +// determine the correct size and offset information of map entries on the C++ +// side. We also rely on it to indicate whether the key is a string or not. #[doc(hidden)] -#[repr(transparent)] -pub struct MapNodeSizeInfoIndex(i32); +#[repr(u8)] +pub enum MapKeyCategory { + OneByte, + FourBytes, + EightBytes, + StdString, +} #[doc(hidden)] -pub trait MapNodeSizeInfoIndexForType { - const SIZE_INFO_INDEX: MapNodeSizeInfoIndex; +pub trait MapKey { + const CATEGORY: MapKeyCategory; } -macro_rules! generate_map_node_size_info_mapping { - ( $($key:ty, $index:expr;)* ) => { +macro_rules! generate_map_key_impl { + ( $($key:ty, $category:expr;)* ) => { $( - impl MapNodeSizeInfoIndexForType for $key { - const SIZE_INFO_INDEX: MapNodeSizeInfoIndex = MapNodeSizeInfoIndex($index); + impl MapKey for $key { + const CATEGORY: MapKeyCategory = $category; } )* } } -// LINT.IfChange(size_info_mapping) -generate_map_node_size_info_mapping!( - i32, 0; - u32, 0; - i64, 1; - u64, 1; - bool, 2; - ProtoString, 3; +// LINT.IfChange(map_key_category) +generate_map_key_impl!( + bool, MapKeyCategory::OneByte; + i32, MapKeyCategory::FourBytes; + u32, MapKeyCategory::FourBytes; + i64, MapKeyCategory::EightBytes; + u64, MapKeyCategory::EightBytes; + ProtoString, MapKeyCategory::StdString; ); -// LINT.ThenChange(//depot/google3/third_party/protobuf/compiler/rust/message. -// cc:size_info_mapping) +// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp_kernel/map.cc: +// map_key_category) macro_rules! impl_map_primitives { (@impl $(($rust_type:ty, $cpp_type:ty) => [ @@ -809,23 +813,22 @@ macro_rules! impl_map_primitives { extern "C" { pub fn $insert_thunk( m: RawMap, - size_info: MapNodeSizeInfo, key: $cpp_type, value: RawMessage, ) -> bool; pub fn $get_thunk( m: RawMap, - size_info: MapNodeSizeInfo, + prototype: RawMessage, key: $cpp_type, value: *mut RawMessage, ) -> bool; pub fn $iter_get_thunk( iter: &mut UntypedMapIterator, - size_info: MapNodeSizeInfo, + prototype: RawMessage, key: *mut $cpp_type, value: *mut RawMessage, ); - pub fn $remove_thunk(m: RawMap, size_info: MapNodeSizeInfo, key: $cpp_type) -> bool; + pub fn $remove_thunk(m: RawMap, prototype: RawMessage, key: $cpp_type) -> bool; } )* }; @@ -856,8 +859,8 @@ extern "C" { fn proto2_rust_thunk_UntypedMapIterator_increment(iter: &mut UntypedMapIterator); pub fn proto2_rust_map_new() -> RawMap; - pub fn proto2_rust_map_free(m: RawMap, key_is_string: bool, size_info: MapNodeSizeInfo); - pub fn proto2_rust_map_clear(m: RawMap, key_is_string: bool, size_info: MapNodeSizeInfo); + pub fn proto2_rust_map_free(m: RawMap, category: MapKeyCategory, prototype: RawMessage); + pub fn proto2_rust_map_clear(m: RawMap, category: MapKeyCategory, prototype: RawMessage); pub fn proto2_rust_map_size(m: RawMap) -> usize; pub fn proto2_rust_map_iter(m: RawMap) -> UntypedMapIterator; } diff --git a/rust/cpp_kernel/map.cc b/rust/cpp_kernel/map.cc index 45b996b92d..7925d2073c 100644 --- a/rust/cpp_kernel/map.cc +++ b/rust/cpp_kernel/map.cc @@ -6,6 +6,7 @@ #include #include +#include "absl/log/absl_log.h" #include "google/protobuf/map.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" @@ -30,6 +31,26 @@ template using KeyMap = internal::KeyMapBase< internal::KeyForBase::type>>; +internal::MapNodeSizeInfoT GetSizeInfo(size_t key_size, + const google::protobuf::MessageLite* value) { + // Each map node consists of a NodeBase followed by a std::pair. + // We need to compute the offset of the value and the total size of the node. + size_t node_and_key_size = sizeof(internal::NodeBase) + key_size; + uint16_t value_size; + uint8_t value_alignment; + internal::RustMapHelper::GetSizeAndAlignment(value, &value_size, + &value_alignment); + // Round node_and_key_size up to the nearest multiple of value_alignment. + uint16_t offset = + (((node_and_key_size - 1) / value_alignment) + 1) * value_alignment; + return internal::RustMapHelper::MakeSizeInfo(offset + value_size, offset); +} + +template +internal::MapNodeSizeInfoT GetSizeInfo(const google::protobuf::MessageLite* value) { + return GetSizeInfo(sizeof(Key), value); +} + template void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node, internal::MapNodeSizeInfoT size_info) { @@ -42,8 +63,9 @@ void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node, } template -bool Insert(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, - Key key, MessageLite* value) { +bool Insert(internal::UntypedMapBase* m, Key key, MessageLite* value) { + internal::MapNodeSizeInfoT size_info = + GetSizeInfo::type>(value); internal::NodeBase* node = internal::RustMapHelper::AllocNode(m, size_info); if constexpr (std::is_same::value) { new (node->GetVoidKey()) std::string(key.ptr, key.len); @@ -89,8 +111,10 @@ internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, } template -bool Get(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, +bool Get(internal::UntypedMapBase* m, const google::protobuf::MessageLite* prototype, Key key, MessageLite** value) { + internal::MapNodeSizeInfoT size_info = + GetSizeInfo::type>(prototype); auto* map_base = static_cast*>(m); internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); if (result.node == nullptr) { @@ -101,8 +125,10 @@ bool Get(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, } template -bool Remove(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, +bool Remove(internal::UntypedMapBase* m, const google::protobuf::MessageLite* prototype, Key key) { + internal::MapNodeSizeInfoT size_info = + GetSizeInfo::type>(prototype); auto* map_base = static_cast*>(m); internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); if (result.node == nullptr) { @@ -115,8 +141,10 @@ bool Remove(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, template void IterGet(const internal::UntypedMapIterator* iter, - internal::MapNodeSizeInfoT size_info, Key* key, + const google::protobuf::MessageLite* prototype, Key* key, MessageLite** value) { + internal::MapNodeSizeInfoT size_info = + GetSizeInfo::type>(prototype); internal::NodeBase* node = iter->node_; if constexpr (std::is_same::value) { const std::string* s = static_cast(node->GetVoidKey()); @@ -127,11 +155,37 @@ void IterGet(const internal::UntypedMapIterator* iter, *value = static_cast(node->GetVoidValue(size_info)); } -void ClearMap(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, - bool key_is_string, bool reset_table) { +// LINT.IfChange(map_key_category) +enum class MapKeyCategory : uint8_t { + kOneByte = 0, + kFourBytes = 1, + kEightBytes = 2, + kStdString = 3, +}; +// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_key_category) + +size_t KeySize(MapKeyCategory category) { + switch (category) { + case MapKeyCategory::kOneByte: + return 1; + case MapKeyCategory::kFourBytes: + return 4; + case MapKeyCategory::kEightBytes: + return 8; + case MapKeyCategory::kStdString: + return sizeof(std::string); + default: + ABSL_DLOG(FATAL) << "Unexpected value of MapKeyCategory enum"; + } +} + +void ClearMap(internal::UntypedMapBase* m, MapKeyCategory category, + bool reset_table, const google::protobuf::MessageLite* prototype) { + internal::MapNodeSizeInfoT size_info = + GetSizeInfo(KeySize(category), prototype); if (internal::RustMapHelper::IsGlobalEmptyTable(m)) return; uint8_t bits = internal::RustMapHelper::kValueIsProto; - if (key_is_string) { + if (category == MapKeyCategory::kStdString) { bits |= internal::RustMapHelper::kKeyIsString; } internal::RustMapHelper::ClearTable( @@ -156,17 +210,16 @@ google::protobuf::internal::UntypedMapBase* proto2_rust_map_new() { } void proto2_rust_map_free(google::protobuf::internal::UntypedMapBase* m, - bool key_is_string, - google::protobuf::internal::MapNodeSizeInfoT size_info) { - google::protobuf::rust::ClearMap(m, size_info, key_is_string, - /* reset_table = */ false); + google::protobuf::rust::MapKeyCategory category, + const google::protobuf::MessageLite* prototype) { + google::protobuf::rust::ClearMap(m, category, /* reset_table = */ false, prototype); delete m; } void proto2_rust_map_clear(google::protobuf::internal::UntypedMapBase* m, - bool key_is_string, - google::protobuf::internal::MapNodeSizeInfoT size_info) { - google::protobuf::rust::ClearMap(m, size_info, key_is_string, /* reset_table = */ true); + google::protobuf::rust::MapKeyCategory category, + const google::protobuf::MessageLite* prototype) { + google::protobuf::rust::ClearMap(m, category, /* reset_table = */ true, prototype); } size_t proto2_rust_map_size(google::protobuf::internal::UntypedMapBase* m) { @@ -178,32 +231,31 @@ google::protobuf::internal::UntypedMapIterator proto2_rust_map_iter( return m->begin(); } -#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ - bool proto2_rust_map_insert_##suffix( \ - google::protobuf::internal::UntypedMapBase* m, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \ - google::protobuf::MessageLite* value) { \ - return google::protobuf::rust::Insert(m, size_info, key, value); \ - } \ - \ - bool proto2_rust_map_get_##suffix( \ - google::protobuf::internal::UntypedMapBase* m, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \ - google::protobuf::MessageLite** value) { \ - return google::protobuf::rust::Get(m, size_info, key, value); \ - } \ - \ - bool proto2_rust_map_remove_##suffix( \ - google::protobuf::internal::UntypedMapBase* m, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key) { \ - return google::protobuf::rust::Remove(m, size_info, key); \ - } \ - \ - void proto2_rust_map_iter_get_##suffix( \ - const google::protobuf::internal::UntypedMapIterator* iter, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type* key, \ - google::protobuf::MessageLite** value) { \ - return google::protobuf::rust::IterGet(iter, size_info, key, value); \ +#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ + bool proto2_rust_map_insert_##suffix(google::protobuf::internal::UntypedMapBase* m, \ + cpp_type key, \ + google::protobuf::MessageLite* value) { \ + return google::protobuf::rust::Insert(m, key, value); \ + } \ + \ + bool proto2_rust_map_get_##suffix(google::protobuf::internal::UntypedMapBase* m, \ + const google::protobuf::MessageLite* prototype, \ + cpp_type key, \ + google::protobuf::MessageLite** value) { \ + return google::protobuf::rust::Get(m, prototype, key, value); \ + } \ + \ + bool proto2_rust_map_remove_##suffix(google::protobuf::internal::UntypedMapBase* m, \ + const google::protobuf::MessageLite* prototype, \ + cpp_type key) { \ + return google::protobuf::rust::Remove(m, prototype, key); \ + } \ + \ + void proto2_rust_map_iter_get_##suffix( \ + const google::protobuf::internal::UntypedMapIterator* iter, \ + const google::protobuf::MessageLite* prototype, cpp_type* key, \ + google::protobuf::MessageLite** value) { \ + return google::protobuf::rust::IterGet(iter, prototype, key, value); \ } DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int32_t, i32) diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 18efb64c04..509b8a0da5 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -204,8 +204,7 @@ void CppMessageExterns(Context& ctx, const Descriptor& msg) { {"repeated_add_thunk", ThunkName(ctx, msg, "repeated_add")}, {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, {"repeated_copy_from_thunk", ThunkName(ctx, msg, "repeated_copy_from")}, - {"repeated_reserve_thunk", ThunkName(ctx, msg, "repeated_reserve")}, - {"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}}, + {"repeated_reserve_thunk", ThunkName(ctx, msg, "repeated_reserve")}}, R"rs( fn $new_thunk$() -> $pbr$::RawMessage; fn $default_instance_thunk$() -> $pbr$::RawMessage; @@ -218,7 +217,6 @@ void CppMessageExterns(Context& ctx, const Descriptor& msg) { fn $repeated_clear_thunk$(raw: $pbr$::RawRepeatedField); fn $repeated_copy_from_thunk$(dst: $pbr$::RawRepeatedField, src: $pbr$::RawRepeatedField); fn $repeated_reserve_thunk$(raw: $pbr$::RawRepeatedField, additional: usize); - fn $map_size_info_thunk$(i: $pbr$::MapNodeSizeInfoIndex) -> $pbr$::MapNodeSizeInfo; )rs"); } @@ -580,8 +578,7 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { case Kernel::kCpp: for (const auto& t : kMapKeyTypes) { ctx.Emit( - {{"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}, - {"map_insert", + {{"map_insert", absl::StrCat("proto2_rust_map_insert_", t.thunk_ident)}, {"map_remove", absl::StrCat("proto2_rust_map_remove_", t.thunk_ident)}, @@ -610,13 +607,13 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { } unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) { - use $pbr$::MapNodeSizeInfoIndexForType; - unsafe { $pbr$::proto2_rust_map_free(map.as_raw($pbi$::Private), $key_is_string$, $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX)); } + use $pbr$::MapKey; + unsafe { $pbr$::proto2_rust_map_free(map.as_raw($pbi$::Private), $key_t$::CATEGORY, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg()); } } fn map_clear(mut map: $pb$::MapMut<$key_t$, Self>) { - use $pbr$::MapNodeSizeInfoIndexForType; - unsafe { $pbr$::proto2_rust_map_clear(map.as_raw($pbi$::Private), $key_is_string$, $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX)); } + use $pbr$::MapKey; + unsafe { $pbr$::proto2_rust_map_clear(map.as_raw($pbi$::Private), $key_t$::CATEGORY, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg()); } } fn map_len(map: $pb$::MapView<$key_t$, Self>) -> usize { @@ -624,24 +621,21 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { } fn map_insert(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>, value: impl $pb$::IntoProxied) -> bool { - use $pbr$::MapNodeSizeInfoIndexForType; unsafe { $pbr$::$map_insert$( map.as_raw($pbi$::Private), - $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), $key_expr$, value.into_proxied($pbi$::Private).raw_msg()) } } fn map_get<'a>(map: $pb$::MapView<'a, $key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> $Option$<$pb$::View<'a, Self>> { - use $pbr$::MapNodeSizeInfoIndexForType; let key = $key_expr$; let mut value = $std$::mem::MaybeUninit::uninit(); let found = unsafe { $pbr$::$map_get$( map.as_raw($pbi$::Private), - $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), + <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(), key, value.as_mut_ptr()) }; @@ -652,11 +646,10 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { } fn map_remove(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> bool { - use $pbr$::MapNodeSizeInfoIndexForType; unsafe { $pbr$::$map_remove$( map.as_raw($pbi$::Private), - $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), + <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(), $key_expr$) } } @@ -676,7 +669,6 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { } fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> $Option$<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> { - use $pbr$::MapNodeSizeInfoIndexForType; // SAFETY: // - The `MapIter` API forbids the backing map from being mutated for 'a, // and guarantees that it's the correct key and value types. @@ -686,7 +678,7 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { unsafe { iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>( |iter, key, value| { $pbr$::$map_iter_get$( - iter, $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), key, value) }, + iter, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(), key, value) }, |ffi_key| $from_ffi_key_expr$, |raw_msg| $Msg$View::new($pbi$::Private, raw_msg) ) @@ -1397,7 +1389,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, {"repeated_copy_from_thunk", ThunkName(ctx, msg, "repeated_copy_from")}, {"repeated_reserve_thunk", ThunkName(ctx, msg, "repeated_reserve")}, - {"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}, {"nested_msg_thunks", [&] { for (int i = 0; i < msg.nested_type_count(); ++i) { @@ -1465,23 +1456,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { size_t additional) { field->Reserve(field->size() + additional); } - google::protobuf::internal::MapNodeSizeInfoT $map_size_info_thunk$(int32_t i) { - static constexpr google::protobuf::internal::MapNodeSizeInfoT size_infos[] = {)cc" - // LINT.IfChange(size_info_mapping) - R"cc( - google::protobuf::internal::RustMapHelper::SizeInfo(), - google::protobuf::internal::RustMapHelper::SizeInfo(), - google::protobuf::internal::RustMapHelper::SizeInfo(), - google::protobuf::internal::RustMapHelper::SizeInfo() - )cc" - // LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:size_info_mapping) - R"cc( - } - ; - return size_infos[i]; - } $accessor_thunks$ diff --git a/src/google/protobuf/map.h b/src/google/protobuf/map.h index 7b28ffd5aa..7a16d0f698 100644 --- a/src/google/protobuf/map.h +++ b/src/google/protobuf/map.h @@ -1173,6 +1173,18 @@ class RustMapHelper { using NodeAndBucket = UntypedMapBase::NodeAndBucket; using ClearInput = UntypedMapBase::ClearInput; + static void GetSizeAndAlignment(const google::protobuf::MessageLite* m, uint16_t* size, + uint8_t* alignment) { + const auto* class_data = m->GetClassData(); + *size = static_cast(class_data->allocation_size()); + *alignment = class_data->alignment(); + } + + static constexpr MapNodeSizeInfoT MakeSizeInfo(uint16_t size, + uint16_t value_offset) { + return MakeNodeInfo(size, value_offset); + } + template static constexpr MapNodeSizeInfoT SizeInfo() { return Map::Node::size_info();