From d900d6114c1df9bb4c74bb3e00540e975520e9de Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Mon, 30 Sep 2024 14:09:02 -0700 Subject: [PATCH] Rust: remove use of `MapNodeSizeInfoT` from generated code We generate these constants to enable map operations, but this is no longer necessary now that we can get the relevant size and alignment information for each message through its vtable. PiperOrigin-RevId: 680712939 --- rust/cpp.rs | 61 +++++---- rust/cpp_kernel/map.cc | 134 +++++++++++++------ src/google/protobuf/compiler/rust/message.cc | 44 ++---- src/google/protobuf/map.h | 12 ++ 4 files changed, 146 insertions(+), 105 deletions(-) diff --git a/rust/cpp.rs b/rust/cpp.rs index 661b3c5613..460910463c 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -110,11 +110,6 @@ extern "C" { pub fn proto2_rust_Message_merge_from(dst: RawMessage, src: RawMessage) -> bool; } -/// An opaque type matching MapNodeSizeInfoT from C++. -#[doc(hidden)] -#[repr(transparent)] -pub struct MapNodeSizeInfo(pub i32); - impl Drop for InnerProtoString { fn drop(&mut self) { // SAFETY: `self.owned_ptr` points to a valid std::string object. @@ -767,36 +762,45 @@ impl UntypedMapIterator { } } +// This enum is used to pass some information about the key type of a map to +// C++. The main purpose is to indicate the size of the key so that we can +// determine the correct size and offset information of map entries on the C++ +// side. We also rely on it to indicate whether the key is a string or not. #[doc(hidden)] -#[repr(transparent)] -pub struct MapNodeSizeInfoIndex(i32); +#[repr(u8)] +pub enum MapKeyCategory { + OneByte, + FourBytes, + EightBytes, + StdString, +} #[doc(hidden)] -pub trait MapNodeSizeInfoIndexForType { - const SIZE_INFO_INDEX: MapNodeSizeInfoIndex; +pub trait MapKey { + const CATEGORY: MapKeyCategory; } -macro_rules! generate_map_node_size_info_mapping { - ( $($key:ty, $index:expr;)* ) => { +macro_rules! generate_map_key_impl { + ( $($key:ty, $category:expr;)* ) => { $( - impl MapNodeSizeInfoIndexForType for $key { - const SIZE_INFO_INDEX: MapNodeSizeInfoIndex = MapNodeSizeInfoIndex($index); + impl MapKey for $key { + const CATEGORY: MapKeyCategory = $category; } )* } } -// LINT.IfChange(size_info_mapping) -generate_map_node_size_info_mapping!( - i32, 0; - u32, 0; - i64, 1; - u64, 1; - bool, 2; - ProtoString, 3; +// LINT.IfChange(map_key_category) +generate_map_key_impl!( + bool, MapKeyCategory::OneByte; + i32, MapKeyCategory::FourBytes; + u32, MapKeyCategory::FourBytes; + i64, MapKeyCategory::EightBytes; + u64, MapKeyCategory::EightBytes; + ProtoString, MapKeyCategory::StdString; ); -// LINT.ThenChange(//depot/google3/third_party/protobuf/compiler/rust/message. -// cc:size_info_mapping) +// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp_kernel/map.cc: +// map_key_category) macro_rules! impl_map_primitives { (@impl $(($rust_type:ty, $cpp_type:ty) => [ @@ -809,23 +813,22 @@ macro_rules! impl_map_primitives { extern "C" { pub fn $insert_thunk( m: RawMap, - size_info: MapNodeSizeInfo, key: $cpp_type, value: RawMessage, ) -> bool; pub fn $get_thunk( m: RawMap, - size_info: MapNodeSizeInfo, + prototype: RawMessage, key: $cpp_type, value: *mut RawMessage, ) -> bool; pub fn $iter_get_thunk( iter: &mut UntypedMapIterator, - size_info: MapNodeSizeInfo, + prototype: RawMessage, key: *mut $cpp_type, value: *mut RawMessage, ); - pub fn $remove_thunk(m: RawMap, size_info: MapNodeSizeInfo, key: $cpp_type) -> bool; + pub fn $remove_thunk(m: RawMap, prototype: RawMessage, key: $cpp_type) -> bool; } )* }; @@ -856,8 +859,8 @@ extern "C" { fn proto2_rust_thunk_UntypedMapIterator_increment(iter: &mut UntypedMapIterator); pub fn proto2_rust_map_new() -> RawMap; - pub fn proto2_rust_map_free(m: RawMap, key_is_string: bool, size_info: MapNodeSizeInfo); - pub fn proto2_rust_map_clear(m: RawMap, key_is_string: bool, size_info: MapNodeSizeInfo); + pub fn proto2_rust_map_free(m: RawMap, category: MapKeyCategory, prototype: RawMessage); + pub fn proto2_rust_map_clear(m: RawMap, category: MapKeyCategory, prototype: RawMessage); pub fn proto2_rust_map_size(m: RawMap) -> usize; pub fn proto2_rust_map_iter(m: RawMap) -> UntypedMapIterator; } diff --git a/rust/cpp_kernel/map.cc b/rust/cpp_kernel/map.cc index 45b996b92d..7925d2073c 100644 --- a/rust/cpp_kernel/map.cc +++ b/rust/cpp_kernel/map.cc @@ -6,6 +6,7 @@ #include #include +#include "absl/log/absl_log.h" #include "google/protobuf/map.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" @@ -30,6 +31,26 @@ template using KeyMap = internal::KeyMapBase< internal::KeyForBase::type>>; +internal::MapNodeSizeInfoT GetSizeInfo(size_t key_size, + const google::protobuf::MessageLite* value) { + // Each map node consists of a NodeBase followed by a std::pair. + // We need to compute the offset of the value and the total size of the node. + size_t node_and_key_size = sizeof(internal::NodeBase) + key_size; + uint16_t value_size; + uint8_t value_alignment; + internal::RustMapHelper::GetSizeAndAlignment(value, &value_size, + &value_alignment); + // Round node_and_key_size up to the nearest multiple of value_alignment. + uint16_t offset = + (((node_and_key_size - 1) / value_alignment) + 1) * value_alignment; + return internal::RustMapHelper::MakeSizeInfo(offset + value_size, offset); +} + +template +internal::MapNodeSizeInfoT GetSizeInfo(const google::protobuf::MessageLite* value) { + return GetSizeInfo(sizeof(Key), value); +} + template void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node, internal::MapNodeSizeInfoT size_info) { @@ -42,8 +63,9 @@ void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node, } template -bool Insert(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, - Key key, MessageLite* value) { +bool Insert(internal::UntypedMapBase* m, Key key, MessageLite* value) { + internal::MapNodeSizeInfoT size_info = + GetSizeInfo::type>(value); internal::NodeBase* node = internal::RustMapHelper::AllocNode(m, size_info); if constexpr (std::is_same::value) { new (node->GetVoidKey()) std::string(key.ptr, key.len); @@ -89,8 +111,10 @@ internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, } template -bool Get(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, +bool Get(internal::UntypedMapBase* m, const google::protobuf::MessageLite* prototype, Key key, MessageLite** value) { + internal::MapNodeSizeInfoT size_info = + GetSizeInfo::type>(prototype); auto* map_base = static_cast*>(m); internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); if (result.node == nullptr) { @@ -101,8 +125,10 @@ bool Get(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, } template -bool Remove(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, +bool Remove(internal::UntypedMapBase* m, const google::protobuf::MessageLite* prototype, Key key) { + internal::MapNodeSizeInfoT size_info = + GetSizeInfo::type>(prototype); auto* map_base = static_cast*>(m); internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); if (result.node == nullptr) { @@ -115,8 +141,10 @@ bool Remove(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, template void IterGet(const internal::UntypedMapIterator* iter, - internal::MapNodeSizeInfoT size_info, Key* key, + const google::protobuf::MessageLite* prototype, Key* key, MessageLite** value) { + internal::MapNodeSizeInfoT size_info = + GetSizeInfo::type>(prototype); internal::NodeBase* node = iter->node_; if constexpr (std::is_same::value) { const std::string* s = static_cast(node->GetVoidKey()); @@ -127,11 +155,37 @@ void IterGet(const internal::UntypedMapIterator* iter, *value = static_cast(node->GetVoidValue(size_info)); } -void ClearMap(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, - bool key_is_string, bool reset_table) { +// LINT.IfChange(map_key_category) +enum class MapKeyCategory : uint8_t { + kOneByte = 0, + kFourBytes = 1, + kEightBytes = 2, + kStdString = 3, +}; +// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_key_category) + +size_t KeySize(MapKeyCategory category) { + switch (category) { + case MapKeyCategory::kOneByte: + return 1; + case MapKeyCategory::kFourBytes: + return 4; + case MapKeyCategory::kEightBytes: + return 8; + case MapKeyCategory::kStdString: + return sizeof(std::string); + default: + ABSL_DLOG(FATAL) << "Unexpected value of MapKeyCategory enum"; + } +} + +void ClearMap(internal::UntypedMapBase* m, MapKeyCategory category, + bool reset_table, const google::protobuf::MessageLite* prototype) { + internal::MapNodeSizeInfoT size_info = + GetSizeInfo(KeySize(category), prototype); if (internal::RustMapHelper::IsGlobalEmptyTable(m)) return; uint8_t bits = internal::RustMapHelper::kValueIsProto; - if (key_is_string) { + if (category == MapKeyCategory::kStdString) { bits |= internal::RustMapHelper::kKeyIsString; } internal::RustMapHelper::ClearTable( @@ -156,17 +210,16 @@ google::protobuf::internal::UntypedMapBase* proto2_rust_map_new() { } void proto2_rust_map_free(google::protobuf::internal::UntypedMapBase* m, - bool key_is_string, - google::protobuf::internal::MapNodeSizeInfoT size_info) { - google::protobuf::rust::ClearMap(m, size_info, key_is_string, - /* reset_table = */ false); + google::protobuf::rust::MapKeyCategory category, + const google::protobuf::MessageLite* prototype) { + google::protobuf::rust::ClearMap(m, category, /* reset_table = */ false, prototype); delete m; } void proto2_rust_map_clear(google::protobuf::internal::UntypedMapBase* m, - bool key_is_string, - google::protobuf::internal::MapNodeSizeInfoT size_info) { - google::protobuf::rust::ClearMap(m, size_info, key_is_string, /* reset_table = */ true); + google::protobuf::rust::MapKeyCategory category, + const google::protobuf::MessageLite* prototype) { + google::protobuf::rust::ClearMap(m, category, /* reset_table = */ true, prototype); } size_t proto2_rust_map_size(google::protobuf::internal::UntypedMapBase* m) { @@ -178,32 +231,31 @@ google::protobuf::internal::UntypedMapIterator proto2_rust_map_iter( return m->begin(); } -#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ - bool proto2_rust_map_insert_##suffix( \ - google::protobuf::internal::UntypedMapBase* m, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \ - google::protobuf::MessageLite* value) { \ - return google::protobuf::rust::Insert(m, size_info, key, value); \ - } \ - \ - bool proto2_rust_map_get_##suffix( \ - google::protobuf::internal::UntypedMapBase* m, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \ - google::protobuf::MessageLite** value) { \ - return google::protobuf::rust::Get(m, size_info, key, value); \ - } \ - \ - bool proto2_rust_map_remove_##suffix( \ - google::protobuf::internal::UntypedMapBase* m, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key) { \ - return google::protobuf::rust::Remove(m, size_info, key); \ - } \ - \ - void proto2_rust_map_iter_get_##suffix( \ - const google::protobuf::internal::UntypedMapIterator* iter, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type* key, \ - google::protobuf::MessageLite** value) { \ - return google::protobuf::rust::IterGet(iter, size_info, key, value); \ +#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ + bool proto2_rust_map_insert_##suffix(google::protobuf::internal::UntypedMapBase* m, \ + cpp_type key, \ + google::protobuf::MessageLite* value) { \ + return google::protobuf::rust::Insert(m, key, value); \ + } \ + \ + bool proto2_rust_map_get_##suffix(google::protobuf::internal::UntypedMapBase* m, \ + const google::protobuf::MessageLite* prototype, \ + cpp_type key, \ + google::protobuf::MessageLite** value) { \ + return google::protobuf::rust::Get(m, prototype, key, value); \ + } \ + \ + bool proto2_rust_map_remove_##suffix(google::protobuf::internal::UntypedMapBase* m, \ + const google::protobuf::MessageLite* prototype, \ + cpp_type key) { \ + return google::protobuf::rust::Remove(m, prototype, key); \ + } \ + \ + void proto2_rust_map_iter_get_##suffix( \ + const google::protobuf::internal::UntypedMapIterator* iter, \ + const google::protobuf::MessageLite* prototype, cpp_type* key, \ + google::protobuf::MessageLite** value) { \ + return google::protobuf::rust::IterGet(iter, prototype, key, value); \ } DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int32_t, i32) diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 18efb64c04..509b8a0da5 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -204,8 +204,7 @@ void CppMessageExterns(Context& ctx, const Descriptor& msg) { {"repeated_add_thunk", ThunkName(ctx, msg, "repeated_add")}, {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, {"repeated_copy_from_thunk", ThunkName(ctx, msg, "repeated_copy_from")}, - {"repeated_reserve_thunk", ThunkName(ctx, msg, "repeated_reserve")}, - {"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}}, + {"repeated_reserve_thunk", ThunkName(ctx, msg, "repeated_reserve")}}, R"rs( fn $new_thunk$() -> $pbr$::RawMessage; fn $default_instance_thunk$() -> $pbr$::RawMessage; @@ -218,7 +217,6 @@ void CppMessageExterns(Context& ctx, const Descriptor& msg) { fn $repeated_clear_thunk$(raw: $pbr$::RawRepeatedField); fn $repeated_copy_from_thunk$(dst: $pbr$::RawRepeatedField, src: $pbr$::RawRepeatedField); fn $repeated_reserve_thunk$(raw: $pbr$::RawRepeatedField, additional: usize); - fn $map_size_info_thunk$(i: $pbr$::MapNodeSizeInfoIndex) -> $pbr$::MapNodeSizeInfo; )rs"); } @@ -580,8 +578,7 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { case Kernel::kCpp: for (const auto& t : kMapKeyTypes) { ctx.Emit( - {{"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}, - {"map_insert", + {{"map_insert", absl::StrCat("proto2_rust_map_insert_", t.thunk_ident)}, {"map_remove", absl::StrCat("proto2_rust_map_remove_", t.thunk_ident)}, @@ -610,13 +607,13 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { } unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) { - use $pbr$::MapNodeSizeInfoIndexForType; - unsafe { $pbr$::proto2_rust_map_free(map.as_raw($pbi$::Private), $key_is_string$, $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX)); } + use $pbr$::MapKey; + unsafe { $pbr$::proto2_rust_map_free(map.as_raw($pbi$::Private), $key_t$::CATEGORY, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg()); } } fn map_clear(mut map: $pb$::MapMut<$key_t$, Self>) { - use $pbr$::MapNodeSizeInfoIndexForType; - unsafe { $pbr$::proto2_rust_map_clear(map.as_raw($pbi$::Private), $key_is_string$, $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX)); } + use $pbr$::MapKey; + unsafe { $pbr$::proto2_rust_map_clear(map.as_raw($pbi$::Private), $key_t$::CATEGORY, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg()); } } fn map_len(map: $pb$::MapView<$key_t$, Self>) -> usize { @@ -624,24 +621,21 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { } fn map_insert(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>, value: impl $pb$::IntoProxied) -> bool { - use $pbr$::MapNodeSizeInfoIndexForType; unsafe { $pbr$::$map_insert$( map.as_raw($pbi$::Private), - $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), $key_expr$, value.into_proxied($pbi$::Private).raw_msg()) } } fn map_get<'a>(map: $pb$::MapView<'a, $key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> $Option$<$pb$::View<'a, Self>> { - use $pbr$::MapNodeSizeInfoIndexForType; let key = $key_expr$; let mut value = $std$::mem::MaybeUninit::uninit(); let found = unsafe { $pbr$::$map_get$( map.as_raw($pbi$::Private), - $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), + <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(), key, value.as_mut_ptr()) }; @@ -652,11 +646,10 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { } fn map_remove(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> bool { - use $pbr$::MapNodeSizeInfoIndexForType; unsafe { $pbr$::$map_remove$( map.as_raw($pbi$::Private), - $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), + <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(), $key_expr$) } } @@ -676,7 +669,6 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { } fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> $Option$<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> { - use $pbr$::MapNodeSizeInfoIndexForType; // SAFETY: // - The `MapIter` API forbids the backing map from being mutated for 'a, // and guarantees that it's the correct key and value types. @@ -686,7 +678,7 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { unsafe { iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>( |iter, key, value| { $pbr$::$map_iter_get$( - iter, $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), key, value) }, + iter, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(), key, value) }, |ffi_key| $from_ffi_key_expr$, |raw_msg| $Msg$View::new($pbi$::Private, raw_msg) ) @@ -1397,7 +1389,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, {"repeated_copy_from_thunk", ThunkName(ctx, msg, "repeated_copy_from")}, {"repeated_reserve_thunk", ThunkName(ctx, msg, "repeated_reserve")}, - {"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}, {"nested_msg_thunks", [&] { for (int i = 0; i < msg.nested_type_count(); ++i) { @@ -1465,23 +1456,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { size_t additional) { field->Reserve(field->size() + additional); } - google::protobuf::internal::MapNodeSizeInfoT $map_size_info_thunk$(int32_t i) { - static constexpr google::protobuf::internal::MapNodeSizeInfoT size_infos[] = {)cc" - // LINT.IfChange(size_info_mapping) - R"cc( - google::protobuf::internal::RustMapHelper::SizeInfo(), - google::protobuf::internal::RustMapHelper::SizeInfo(), - google::protobuf::internal::RustMapHelper::SizeInfo(), - google::protobuf::internal::RustMapHelper::SizeInfo() - )cc" - // LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:size_info_mapping) - R"cc( - } - ; - return size_infos[i]; - } $accessor_thunks$ diff --git a/src/google/protobuf/map.h b/src/google/protobuf/map.h index 7b28ffd5aa..7a16d0f698 100644 --- a/src/google/protobuf/map.h +++ b/src/google/protobuf/map.h @@ -1173,6 +1173,18 @@ class RustMapHelper { using NodeAndBucket = UntypedMapBase::NodeAndBucket; using ClearInput = UntypedMapBase::ClearInput; + static void GetSizeAndAlignment(const google::protobuf::MessageLite* m, uint16_t* size, + uint8_t* alignment) { + const auto* class_data = m->GetClassData(); + *size = static_cast(class_data->allocation_size()); + *alignment = class_data->alignment(); + } + + static constexpr MapNodeSizeInfoT MakeSizeInfo(uint16_t size, + uint16_t value_offset) { + return MakeNodeInfo(size, value_offset); + } + template static constexpr MapNodeSizeInfoT SizeInfo() { return Map::Node::size_info();