From 6ab302d3a3c0a09b2f0d0e951600e4b5b02686d0 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Tue, 30 Jul 2024 12:03:36 -0700 Subject: [PATCH] Rust: cut down on the amount of generated C++ code needed for maps With the C++ kernel for Rust, we currently need to generate quite a few C++ thunks for operations on map fields. For each message we generate, we generate these thunks for all possible map types that could have that message as a value. These operations are for things such as insertion, removal, clearing, iterating, etc. The reason we do this is that templated types don't play well with FFI, so we effectively need separate FFI endpoints for every possible combination of key and value types used (or even potentially used) as a map field. This CL fixes the problem by replacing the generated thunks with functions in the runtime that can operate on `proto2::MessageLite*` without needing to care about the specific message type. The way it works is that we implement the operations using either `UntypedMapBase` (the base class of all map types, which knows nothing about the key and value types) or `KeyMapBase`, which knows the key type but not the value type. I roughly followed the example of the table-driven parser, which has a similar problem of needing to operate generically on maps without having access to the concrete types. I removed 54 thunks per message (that's 6 key types times 9 operations per key), but had to add two new thunks per message: - The `size_info` thunk looks up the `MapNodeSizeInfoT`, which is stored in a small constant table. The important thing here is an offset indicating where to look for the value in each map entry. This offset can be different for every pair of key and value types, but we can safely assume that the result does not depend on the signedness of the key. As a result we only need to store four entries per message: one each for i32, i64, bool, and string. - The `placement_new` thunk move-constructs a message in place. We need this to be able to efficiently implement map insertion. There are two big things that this CL does not address yet but which I plan to follow up on: - Enums still generate many map-related C++ thunks that could be replaced with a common implementation. This should actually be much easier to handle than messages, because every enum has the same representation as an i32. - We still generate six `ProxiedInMapValue` implementations for every message, but it should be possible to replace these with a blanket implementation that works for all message types. PiperOrigin-RevId: 657681421 --- rust/cpp.rs | 104 ++++++++- rust/cpp_kernel/BUILD | 1 + rust/cpp_kernel/map.cc | 188 ++++++++++++++++ rust/cpp_kernel/map.h | 8 +- rust/test/shared/accessors_map_test.rs | 50 +++++ src/google/protobuf/compiler/rust/enum.cc | 3 +- src/google/protobuf/compiler/rust/message.cc | 224 ++++++++----------- src/google/protobuf/map.h | 66 ++++++ src/google/protobuf/message_lite.h | 2 + 9 files changed, 509 insertions(+), 137 deletions(-) diff --git a/rust/cpp.rs b/rust/cpp.rs index 4410f21d51..4d3358a8e8 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -95,6 +95,11 @@ pub struct InnerProtoString { owned_ptr: CppStdString, } +/// An opaque type matching MapNodeSizeInfoT from C++. +#[doc(hidden)] +#[repr(transparent)] +pub struct MapNodeSizeInfo(pub i32); + impl Drop for InnerProtoString { fn drop(&mut self) { // SAFETY: `self.owned_ptr` points to a valid std::string object. @@ -690,9 +695,11 @@ impl UntypedMapIterator { _private: Private, iter_get_thunk: unsafe extern "C" fn( iter: &mut UntypedMapIterator, + size_info: MapNodeSizeInfo, key: *mut FfiKey, value: *mut FfiValue, ), + size_info: MapNodeSizeInfo, from_ffi_key: impl FnOnce(FfiKey) -> View<'a, K>, from_ffi_value: impl FnOnce(FfiValue) -> View<'a, V>, ) -> Option<(View<'a, K>, View<'a, V>)> @@ -710,7 +717,7 @@ impl UntypedMapIterator { // - The iterator is not at the end (node is non-null). // - `ffi_key` and `ffi_value` are not read (as uninit) as promised by the // caller. - unsafe { (iter_get_thunk)(self, ffi_key.as_mut_ptr(), ffi_value.as_mut_ptr()) } + unsafe { (iter_get_thunk)(self, size_info, ffi_key.as_mut_ptr(), ffi_value.as_mut_ptr()) } // SAFETY: // - The backing map is alive as promised by the caller. @@ -733,8 +740,100 @@ impl UntypedMapIterator { } } +#[doc(hidden)] +#[repr(transparent)] +pub struct MapNodeSizeInfoIndex(i32); + +#[doc(hidden)] +pub trait MapNodeSizeInfoIndexForType { + const SIZE_INFO_INDEX: MapNodeSizeInfoIndex; +} + +macro_rules! generate_map_node_size_info_mapping { + ( $($key:ty, $index:expr;)* ) => { + $( + impl MapNodeSizeInfoIndexForType for $key { + const SIZE_INFO_INDEX: MapNodeSizeInfoIndex = MapNodeSizeInfoIndex($index); + } + )* + } +} + +// LINT.IfChange(size_info_mapping) +generate_map_node_size_info_mapping!( + i32, 0; + u32, 0; + i64, 1; + u64, 1; + bool, 2; + ProtoString, 3; +); +// LINT.ThenChange(//depot/google3/third_party/protobuf/compiler/rust/message. +// cc:size_info_mapping) + +macro_rules! impl_map_primitives { + (@impl $(($rust_type:ty, $cpp_type:ty) => [ + $insert_thunk:ident, + $get_thunk:ident, + $iter_get_thunk:ident, + $remove_thunk:ident, + ]),* $(,)?) => { + $( + extern "C" { + pub fn $insert_thunk( + m: RawMap, + size_info: MapNodeSizeInfo, + key: $cpp_type, + value: RawMessage, + placement_new: unsafe extern "C" fn(*mut c_void, m: RawMessage), + ) -> bool; + pub fn $get_thunk( + m: RawMap, + size_info: MapNodeSizeInfo, + key: $cpp_type, + value: *mut RawMessage, + ) -> bool; + pub fn $iter_get_thunk( + iter: &mut UntypedMapIterator, + size_info: MapNodeSizeInfo, + key: *mut $cpp_type, + value: *mut RawMessage, + ); + pub fn $remove_thunk(m: RawMap, size_info: MapNodeSizeInfo, key: $cpp_type) -> bool; + } + )* + }; + ($($rust_type:ty, $cpp_type:ty;)* $(,)?) => { + paste!{ + impl_map_primitives!(@impl $( + ($rust_type, $cpp_type) => [ + [< proto2_rust_map_insert_ $rust_type >], + [< proto2_rust_map_get_ $rust_type >], + [< proto2_rust_map_iter_get_ $rust_type >], + [< proto2_rust_map_remove_ $rust_type >], + ], + )*); + } + }; +} + +impl_map_primitives!( + i32, i32; + u32, u32; + i64, i64; + u64, u64; + bool, bool; + ProtoString, PtrAndLen; +); + extern "C" { fn proto2_rust_thunk_UntypedMapIterator_increment(iter: &mut UntypedMapIterator); + + pub fn proto2_rust_map_new() -> RawMap; + pub fn proto2_rust_map_free(m: RawMap, key_is_string: bool, size_info: MapNodeSizeInfo); + pub fn proto2_rust_map_clear(m: RawMap, key_is_string: bool, size_info: MapNodeSizeInfo); + pub fn proto2_rust_map_size(m: RawMap) -> usize; + pub fn proto2_rust_map_iter(m: RawMap) -> UntypedMapIterator; } macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { @@ -748,7 +847,7 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { fn [< proto2_rust_thunk_Map_ $key_t _ $t _insert >](m: RawMap, key: $ffi_key_t, value: $ffi_value_t) -> bool; fn [< proto2_rust_thunk_Map_ $key_t _ $t _get >](m: RawMap, key: $ffi_key_t, value: *mut $ffi_view_t) -> bool; fn [< proto2_rust_thunk_Map_ $key_t _ $t _iter >](m: RawMap) -> UntypedMapIterator; - fn [< proto2_rust_thunk_Map_ $key_t _ $t _iter_get >](iter: &mut UntypedMapIterator, key: *mut $ffi_key_t, value: *mut $ffi_view_t); + fn [< proto2_rust_thunk_Map_ $key_t _ $t _iter_get >](iter: &mut UntypedMapIterator, size_info: MapNodeSizeInfo, key: *mut $ffi_key_t, value: *mut $ffi_view_t); fn [< proto2_rust_thunk_Map_ $key_t _ $t _remove >](m: RawMap, key: $ffi_key_t, value: *mut $ffi_view_t) -> bool; } @@ -829,6 +928,7 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { iter.as_raw_mut(Private).next_unchecked::<$key_t, Self, _, _>( Private, [< proto2_rust_thunk_Map_ $key_t _ $t _iter_get >], + MapNodeSizeInfo(0), $from_ffi_key, $from_ffi_value, ) diff --git a/rust/cpp_kernel/BUILD b/rust/cpp_kernel/BUILD index 7d0a256f1c..c1c1cdb7e1 100644 --- a/rust/cpp_kernel/BUILD +++ b/rust/cpp_kernel/BUILD @@ -27,6 +27,7 @@ cc_library( "//src/google/protobuf:protobuf_lite", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/rust/cpp_kernel/map.cc b/rust/cpp_kernel/map.cc index 2418c7274c..60f53c7df0 100644 --- a/rust/cpp_kernel/map.cc +++ b/rust/cpp_kernel/map.cc @@ -1,12 +1,137 @@ #include "rust/cpp_kernel/map.h" +#include #include #include +#include #include #include "google/protobuf/map.h" +#include "google/protobuf/message_lite.h" #include "rust/cpp_kernel/strings.h" +namespace google { +namespace protobuf { +namespace rust { +namespace { + +template +struct FromViewType { + using type = T; +}; + +template <> +struct FromViewType { + using type = std::string; +}; + +template +using KeyMap = internal::KeyMapBase< + internal::KeyForBase::type>>; + +template +void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node, + internal::MapNodeSizeInfoT size_info) { + if constexpr (std::is_same::value) { + static_cast(node->GetVoidKey())->~basic_string(); + } + internal::RustMapHelper::DestroyMessage( + static_cast(node->GetVoidValue(size_info))); + internal::RustMapHelper::DeallocNode(m, node, size_info); +} + +template +bool Insert(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, + Key key, MessageLite* value, + void (*placement_new)(void*, MessageLite*)) { + internal::NodeBase* node = internal::RustMapHelper::AllocNode(m, size_info); + if constexpr (std::is_same::value) { + new (node->GetVoidKey()) std::string(key.ptr, key.len); + } else { + *static_cast(node->GetVoidKey()) = key; + } + void* value_ptr = node->GetVoidValue(size_info); + placement_new(value_ptr, value); + node = internal::RustMapHelper::InsertOrReplaceNode( + static_cast*>(m), node); + if (node == nullptr) { + return true; + } + DestroyMapNode(m, node, size_info); + return false; +} + +template ::value>::type> +internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, Key key) { + return internal::RustMapHelper::FindHelper( + m, static_cast>(key)); +} + +template +internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, + google::protobuf::rust::PtrAndLen key) { + return internal::RustMapHelper::FindHelper( + m, absl::string_view(key.ptr, key.len)); +} + +template +bool Get(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, + Key key, MessageLite** value) { + auto* map_base = static_cast*>(m); + internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); + if (result.node == nullptr) { + return false; + } + *value = static_cast(result.node->GetVoidValue(size_info)); + return true; +} + +template +bool Remove(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, + Key key) { + auto* map_base = static_cast*>(m); + internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); + if (result.node == nullptr) { + return false; + } + internal::RustMapHelper::EraseNoDestroy(map_base, result.bucket, result.node); + DestroyMapNode(m, result.node, size_info); + return true; +} + +template +void IterGet(const internal::UntypedMapIterator* iter, + internal::MapNodeSizeInfoT size_info, Key* key, + MessageLite** value) { + internal::NodeBase* node = iter->node_; + if constexpr (std::is_same::value) { + const std::string* s = static_cast(node->GetVoidKey()); + *key = PtrAndLen(s->data(), s->size()); + } else { + *key = *static_cast(node->GetVoidKey()); + } + *value = static_cast(node->GetVoidValue(size_info)); +} + +void ClearMap(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, + bool key_is_string, bool reset_table) { + if (internal::RustMapHelper::IsGlobalEmptyTable(m)) return; + uint8_t bits = internal::RustMapHelper::kValueIsProto; + if (key_is_string) { + bits |= internal::RustMapHelper::kKeyIsString; + } + internal::RustMapHelper::ClearTable( + m, internal::RustMapHelper::ClearInput{size_info, bits, reset_table, + /* destroy_node = */ nullptr}); +} + +} // namespace +} // namespace rust +} // namespace protobuf +} // namespace google + extern "C" { void proto2_rust_thunk_UntypedMapIterator_increment( @@ -14,6 +139,69 @@ void proto2_rust_thunk_UntypedMapIterator_increment( iter->PlusPlus(); } +google::protobuf::internal::UntypedMapBase* proto2_rust_map_new() { + return new google::protobuf::internal::UntypedMapBase(/* arena = */ nullptr); +} + +void proto2_rust_map_free(google::protobuf::internal::UntypedMapBase* m, + bool key_is_string, + google::protobuf::internal::MapNodeSizeInfoT size_info) { + google::protobuf::rust::ClearMap(m, size_info, key_is_string, + /* reset_table = */ false); + delete m; +} + +void proto2_rust_map_clear(google::protobuf::internal::UntypedMapBase* m, + bool key_is_string, + google::protobuf::internal::MapNodeSizeInfoT size_info) { + google::protobuf::rust::ClearMap(m, size_info, key_is_string, /* reset_table = */ true); +} + +size_t proto2_rust_map_size(google::protobuf::internal::UntypedMapBase* m) { + return m->size(); +} + +google::protobuf::internal::UntypedMapIterator proto2_rust_map_iter( + google::protobuf::internal::UntypedMapBase* m) { + return m->begin(); +} + +#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ + bool proto2_rust_map_insert_##suffix( \ + google::protobuf::internal::UntypedMapBase* m, \ + google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \ + google::protobuf::MessageLite* value, \ + void (*placement_new)(void*, google::protobuf::MessageLite*)) { \ + return google::protobuf::rust::Insert(m, size_info, key, value, placement_new); \ + } \ + \ + bool proto2_rust_map_get_##suffix( \ + google::protobuf::internal::UntypedMapBase* m, \ + google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \ + google::protobuf::MessageLite** value) { \ + return google::protobuf::rust::Get(m, size_info, key, value); \ + } \ + \ + bool proto2_rust_map_remove_##suffix( \ + google::protobuf::internal::UntypedMapBase* m, \ + google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key) { \ + return google::protobuf::rust::Remove(m, size_info, key); \ + } \ + \ + void proto2_rust_map_iter_get_##suffix( \ + const google::protobuf::internal::UntypedMapIterator* iter, \ + google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type* key, \ + google::protobuf::MessageLite** value) { \ + return google::protobuf::rust::IterGet(iter, size_info, key, value); \ + } + +DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int32_t, i32) +DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(uint32_t, u32) +DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int64_t, i64) +DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(uint64_t, u64) +DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(bool, bool) +DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(google::protobuf::rust::PtrAndLen, ProtoString) + __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(int32_t, i32, int32_t, int32_t, value, cpp_value); __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(uint32_t, u32, uint32_t, diff --git a/rust/cpp_kernel/map.h b/rust/cpp_kernel/map.h index c16f14b1f3..e568222edb 100644 --- a/rust/cpp_kernel/map.h +++ b/rust/cpp_kernel/map.h @@ -4,6 +4,10 @@ #include #include +#include "google/protobuf/map.h" +#include "google/protobuf/message_lite.h" +#include "rust/cpp_kernel/strings.h" + namespace google { namespace protobuf { namespace rust { @@ -74,8 +78,8 @@ auto MakeCleanup(T value) { return google::protobuf::internal::UntypedMapIterator::FromTyped(m->cbegin()); \ } \ void proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_iter_get( \ - const google::protobuf::internal::UntypedMapIterator* iter, ffi_key_ty* key, \ - ffi_view_ty* value) { \ + const google::protobuf::internal::UntypedMapIterator* iter, int32_t, \ + ffi_key_ty* key, ffi_view_ty* value) { \ auto typed_iter = \ iter->ToTyped::const_iterator>(); \ const auto& cpp_key = typed_iter->first; \ diff --git a/rust/test/shared/accessors_map_test.rs b/rust/test/shared/accessors_map_test.rs index 4ec026eb1f..942efb6efd 100644 --- a/rust/test/shared/accessors_map_test.rs +++ b/rust/test/shared/accessors_map_test.rs @@ -212,6 +212,56 @@ fn test_map_setter() { } } +#[test] +fn test_map_creation_with_message_values() { + // Maps are usually created and owned by a parent message, but let's verify that + // we can successfully create and destroy them independently. + macro_rules! test_for_each_key { + ($($key_t:ty, $key:expr;)*) => { + $( + let msg = TestAllTypes::new(); + let mut map = protobuf::Map::<$key_t, TestAllTypes>::new(); + map.as_mut().insert($key, msg); + assert_that!(map.as_view().len(), eq(1)); + )* + } + } + + test_for_each_key!( + i32, -5; + u32, 13u32; + i64, 7; + u64, 11u64; + bool, false; + ProtoString, "looooooooooooooooooooooooong string"; + ); +} + +#[test] +fn test_map_clearing_with_message_values() { + macro_rules! test_for_each_key { + ($($key_t:ty, $key:expr;)*) => { + $( + let msg = TestAllTypes::new(); + let mut map = protobuf::Map::<$key_t, TestAllTypes>::new(); + map.as_mut().insert($key, msg); + assert_that!(map.as_view().len(), eq(1)); + map.as_mut().clear(); + assert_that!(map.as_view().len(), eq(0)); + )* + } + } + + test_for_each_key!( + i32, -5; + u32, 13u32; + i64, 7; + u64, 11u64; + bool, false; + ProtoString, "looooooooooooooooooooooooong string"; + ); +} + macro_rules! generate_map_with_msg_values_tests { ( $(($k_field:ident, $k_nonzero:expr, $k_other:expr $(,)?)),* diff --git a/src/google/protobuf/compiler/rust/enum.cc b/src/google/protobuf/compiler/rust/enum.cc index 9b4a4402cf..8d5ccb54e9 100644 --- a/src/google/protobuf/compiler/rust/enum.cc +++ b/src/google/protobuf/compiler/rust/enum.cc @@ -81,7 +81,7 @@ void EnumProxiedInMapValue(Context& ctx, const EnumDescriptor& desc) { fn $map_get_thunk$(m: $pbr$::RawMap, key: $ffi_key_t$, value: *mut $name$) -> bool; fn $map_remove_thunk$(m: $pbr$::RawMap, key: $ffi_key_t$, value: *mut $name$) -> bool; fn $map_iter_thunk$(m: $pbr$::RawMap) -> $pbr$::UntypedMapIterator; - fn $map_iter_get_thunk$(iter: &mut $pbr$::UntypedMapIterator, key: *mut $ffi_key_t$, value: *mut $name$); + fn $map_iter_get_thunk$(iter: &mut $pbr$::UntypedMapIterator, size_info: $pbr$::MapNodeSizeInfo, key: *mut $ffi_key_t$, value: *mut $name$); } impl $pb$::ProxiedInMapValue<$key_t$> for $name$ { fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> { @@ -149,6 +149,7 @@ void EnumProxiedInMapValue(Context& ctx, const EnumDescriptor& desc) { iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>( $pbi$::Private, $map_iter_get_thunk$, + $pbr$::MapNodeSizeInfo(0), // Ignored |ffi_key| $from_ffi_key_expr$, $std$::convert::identity, ) diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index d18a58deda..4665f44bf6 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -11,6 +11,7 @@ #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" #include "google/protobuf/compiler/cpp/names.h" @@ -197,29 +198,28 @@ void MessageExterns(Context& ctx, const Descriptor& msg) { switch (ctx.opts().kernel) { case Kernel::kCpp: ctx.Emit( - { - {"new_thunk", ThunkName(ctx, msg, "new")}, - {"delete_thunk", ThunkName(ctx, msg, "delete")}, - {"clear_thunk", ThunkName(ctx, msg, "clear")}, - {"serialize_thunk", ThunkName(ctx, msg, "serialize")}, - {"parse_thunk", ThunkName(ctx, msg, "parse")}, - {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}, - {"merge_from_thunk", ThunkName(ctx, msg, "merge_from")}, - {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")}, - {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")}, - {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, - {"repeated_get_thunk", ThunkName(ctx, msg, "repeated_get")}, - {"repeated_get_mut_thunk", - ThunkName(ctx, msg, "repeated_get_mut")}, - {"repeated_add_thunk", ThunkName(ctx, msg, "repeated_add")}, - {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, - {"repeated_copy_from_thunk", - ThunkName(ctx, msg, "repeated_copy_from")}, - {"repeated_reserve_thunk", - ThunkName(ctx, msg, "repeated_reserve")}, - }, + {{"new_thunk", ThunkName(ctx, msg, "new")}, + {"placement_new_thunk", ThunkName(ctx, msg, "placement_new")}, + {"delete_thunk", ThunkName(ctx, msg, "delete")}, + {"clear_thunk", ThunkName(ctx, msg, "clear")}, + {"serialize_thunk", ThunkName(ctx, msg, "serialize")}, + {"parse_thunk", ThunkName(ctx, msg, "parse")}, + {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}, + {"merge_from_thunk", ThunkName(ctx, msg, "merge_from")}, + {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")}, + {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")}, + {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, + {"repeated_get_thunk", ThunkName(ctx, msg, "repeated_get")}, + {"repeated_get_mut_thunk", ThunkName(ctx, msg, "repeated_get_mut")}, + {"repeated_add_thunk", ThunkName(ctx, msg, "repeated_add")}, + {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, + {"repeated_copy_from_thunk", + ThunkName(ctx, msg, "repeated_copy_from")}, + {"repeated_reserve_thunk", ThunkName(ctx, msg, "repeated_reserve")}, + {"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}}, R"rs( fn $new_thunk$() -> $pbr$::RawMessage; + fn $placement_new_thunk$(ptr: *mut std::ffi::c_void, m: $pbr$::RawMessage); fn $delete_thunk$(raw_msg: $pbr$::RawMessage); fn $clear_thunk$(raw_msg: $pbr$::RawMessage); fn $serialize_thunk$(raw_msg: $pbr$::RawMessage, out: &mut $pbr$::SerializedData) -> bool; @@ -235,6 +235,7 @@ void MessageExterns(Context& ctx, const Descriptor& msg) { fn $repeated_clear_thunk$(raw: $pbr$::RawRepeatedField); fn $repeated_copy_from_thunk$(dst: $pbr$::RawRepeatedField, src: $pbr$::RawRepeatedField); fn $repeated_reserve_thunk$(raw: $pbr$::RawRepeatedField, additional: usize); + fn $map_size_info_thunk$(i: $pbr$::MapNodeSizeInfoIndex) -> $pbr$::MapNodeSizeInfo; )rs"); return; @@ -590,19 +591,18 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { case Kernel::kCpp: for (const auto& t : kMapKeyTypes) { ctx.Emit( - {{"map_new_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "new")}, - {"map_free_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "free")}, - {"map_clear_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "clear")}, - {"map_size_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "size")}, - {"map_insert_thunk", - RawMapThunk(ctx, msg, t.thunk_ident, "insert")}, - {"map_get_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "get")}, - {"map_remove_thunk", - RawMapThunk(ctx, msg, t.thunk_ident, "remove")}, - {"map_iter_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "iter")}, - {"map_iter_get_thunk", - RawMapThunk(ctx, msg, t.thunk_ident, "iter_get")}, + {{"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}, + {"placement_new_thunk", ThunkName(ctx, msg, "placement_new")}, + {"map_insert", + absl::StrCat("proto2_rust_map_insert_", t.thunk_ident)}, + {"map_remove", + absl::StrCat("proto2_rust_map_remove_", t.thunk_ident)}, + {"map_get", absl::StrCat("proto2_rust_map_get_", t.thunk_ident)}, + {"map_iter_get", + absl::StrCat("proto2_rust_map_iter_get_", t.thunk_ident)}, {"key_expr", t.rs_to_ffi_key_expr}, + {"key_is_string", + t.thunk_ident == "ProtoString" ? "true" : "false"}, io::Printer::Sub("ffi_key_t", [&] { ctx.Emit(t.rs_ffi_key_t); }) .WithSuffix(""), io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); }) @@ -611,47 +611,52 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { [&] { ctx.Emit(t.rs_from_ffi_key_expr); }) .WithSuffix("")}, R"rs( - extern "C" { - fn $map_new_thunk$() -> $pbr$::RawMap; - fn $map_free_thunk$(m: $pbr$::RawMap); - fn $map_clear_thunk$(m: $pbr$::RawMap); - fn $map_size_thunk$(m: $pbr$::RawMap) -> usize; - fn $map_insert_thunk$(m: $pbr$::RawMap, key: $ffi_key_t$, value: $pbr$::RawMessage) -> bool; - fn $map_get_thunk$(m: $pbr$::RawMap, key: $ffi_key_t$, value: *mut $pbr$::RawMessage) -> bool; - fn $map_remove_thunk$(m: $pbr$::RawMap, key: $ffi_key_t$, value: *mut $pbr$::RawMessage) -> bool; - fn $map_iter_thunk$(m: $pbr$::RawMap) -> $pbr$::UntypedMapIterator; - fn $map_iter_get_thunk$(iter: &mut $pbr$::UntypedMapIterator, key: *mut $ffi_key_t$, value: *mut $pbr$::RawMessage); - } impl $pb$::ProxiedInMapValue<$key_t$> for $Msg$ { fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> { unsafe { $pb$::Map::from_inner( $pbi$::Private, - $pbr$::InnerMap::new($pbi$::Private, $map_new_thunk$()) + $pbr$::InnerMap::new($pbi$::Private, $pbr$::proto2_rust_map_new()) ) } } unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) { - unsafe { $map_free_thunk$(map.as_raw($pbi$::Private)); } + use $pbr$::MapNodeSizeInfoIndexForType; + unsafe { $pbr$::proto2_rust_map_free(map.as_raw($pbi$::Private), $key_is_string$, $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX)); } } fn map_clear(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>) { - unsafe { $map_clear_thunk$(map.as_raw($pbi$::Private)); } + use $pbr$::MapNodeSizeInfoIndexForType; + unsafe { $pbr$::proto2_rust_map_clear(map.as_raw($pbi$::Private), $key_is_string$, $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX)); } } fn map_len(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> usize { - unsafe { $map_size_thunk$(map.as_raw($pbi$::Private)) } + unsafe { $pbr$::proto2_rust_map_size(map.as_raw($pbi$::Private)) } } fn map_insert(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>, value: impl $pb$::IntoProxied) -> bool { - unsafe { $map_insert_thunk$(map.as_raw($pbi$::Private), $key_expr$, value.into_proxied($pbi$::Private).raw_msg()) } + use $pbr$::MapNodeSizeInfoIndexForType; + unsafe { + $pbr$::$map_insert$( + map.as_raw($pbi$::Private), + $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), + $key_expr$, + value.into_proxied($pbi$::Private).raw_msg(), $placement_new_thunk$) + } } fn map_get<'a>(map: $pb$::View<'a, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> Option<$pb$::View<'a, Self>> { + use $pbr$::MapNodeSizeInfoIndexForType; let key = $key_expr$; let mut value = $std$::mem::MaybeUninit::uninit(); - let found = unsafe { $map_get_thunk$(map.as_raw($pbi$::Private), key, value.as_mut_ptr()) }; + let found = unsafe { + $pbr$::$map_get$( + map.as_raw($pbi$::Private), + $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), + key, + value.as_mut_ptr()) + }; if !found { return None; } @@ -659,8 +664,13 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { } fn map_remove(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> bool { - let mut value = $std$::mem::MaybeUninit::uninit(); - unsafe { $map_remove_thunk$(map.as_raw($pbi$::Private), $key_expr$, value.as_mut_ptr()) } + use $pbr$::MapNodeSizeInfoIndexForType; + unsafe { + $pbr$::$map_remove$( + map.as_raw($pbi$::Private), + $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), + $key_expr$) + } } fn map_iter(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> $pb$::MapIter<'_, $key_t$, Self> { @@ -672,12 +682,13 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { unsafe { $pb$::MapIter::from_raw( $pbi$::Private, - $map_iter_thunk$(map.as_raw($pbi$::Private)) + $pbr$::proto2_rust_map_iter(map.as_raw($pbi$::Private)) ) } } fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> Option<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> { + use $pbr$::MapNodeSizeInfoIndexForType; // SAFETY: // - The `MapIter` API forbids the backing map from being mutated for 'a, // and guarantees that it's the correct key and value types. @@ -687,7 +698,8 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { unsafe { iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>( $pbi$::Private, - $map_iter_get_thunk$, + $pbr$::$map_iter_get$, + $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), |ffi_key| $from_ffi_key_expr$, |raw_msg| $Msg$View::new($pbi$::Private, raw_msg) ) @@ -1390,6 +1402,7 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { {"Msg", RsSafeName(msg.name())}, {"QualifiedMsg", cpp::QualifiedClassName(&msg)}, {"new_thunk", ThunkName(ctx, msg, "new")}, + {"placement_new_thunk", ThunkName(ctx, msg, "placement_new")}, {"delete_thunk", ThunkName(ctx, msg, "delete")}, {"clear_thunk", ThunkName(ctx, msg, "clear")}, {"serialize_thunk", ThunkName(ctx, msg, "serialize")}, @@ -1405,6 +1418,7 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, {"repeated_copy_from_thunk", ThunkName(ctx, msg, "repeated_copy_from")}, {"repeated_reserve_thunk", ThunkName(ctx, msg, "repeated_reserve")}, + {"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}, {"nested_msg_thunks", [&] { for (int i = 0; i < msg.nested_type_count(); ++i) { @@ -1427,12 +1441,15 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { } }}}, R"cc( - //~ $abi$ is a workaround for a syntax highlight bug in VSCode. However, - //~ that confuses clang-format (it refuses to keep the newline after - //~ `$abi${`). Disabling clang-format for the block. + //~ $abi$ is a workaround for a syntax highlight bug in VSCode. + // However, ~ that confuses clang-format (it refuses to keep the + // newline after ~ `$abi${`). Disabling clang-format for the block. // clang-format off extern $abi$ { void* $new_thunk$() { return new $QualifiedMsg$(); } + void $placement_new_thunk$(void* ptr, $QualifiedMsg$& m) { + new (ptr) $QualifiedMsg$(std::move(m)); + } void $delete_thunk$(void* ptr) { delete static_cast<$QualifiedMsg$*>(ptr); } void $clear_thunk$(void* ptr) { static_cast<$QualifiedMsg$*>(ptr)->Clear(); @@ -1490,89 +1507,32 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { size_t additional) { field->Reserve(field->size() + additional); } + google::protobuf::internal::MapNodeSizeInfoT $map_size_info_thunk$(int32_t i) { + static constexpr google::protobuf::internal::MapNodeSizeInfoT size_infos[] = {)cc" + // LINT.IfChange(size_info_mapping) + R"cc( + google::protobuf::internal::RustMapHelper::SizeInfo(), + google::protobuf::internal::RustMapHelper::SizeInfo(), + google::protobuf::internal::RustMapHelper::SizeInfo(), + google::protobuf::internal::RustMapHelper::SizeInfo() + )cc" + // LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:size_info_mapping) + R"cc( + } + ; + return size_infos[i]; + } $accessor_thunks$ - $oneof_thunks$ + $oneof_thunks$ } // extern $abi$ // clang-format on $nested_msg_thunks$ )cc"); - for (const auto& t : kMapKeyTypes) { - ctx.Emit( - { - {"map_new_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "new")}, - {"map_free_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "free")}, - {"map_clear_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "clear")}, - {"map_size_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "size")}, - {"map_insert_thunk", - RawMapThunk(ctx, msg, t.thunk_ident, "insert")}, - {"map_get_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "get")}, - {"map_remove_thunk", - RawMapThunk(ctx, msg, t.thunk_ident, "remove")}, - {"map_iter_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "iter")}, - {"map_iter_get_thunk", - RawMapThunk(ctx, msg, t.thunk_ident, "iter_get")}, - {"key_t", t.cc_key_t}, - {"ffi_key_t", t.cc_ffi_key_t}, - {"key_expr", t.cc_from_ffi_key_expr}, - {"to_ffi_key_expr", t.cc_to_ffi_key_expr}, - {"pkg::Msg", cpp::QualifiedClassName(&msg)}, - {"abi", "\"C\""}, // Workaround for syntax highlight bug in VSCode. - }, - R"cc( - extern $abi$ { - const google::protobuf::Map<$key_t$, $pkg::Msg$>* $map_new_thunk$() { - return new google::protobuf::Map<$key_t$, $pkg::Msg$>(); - } - void $map_free_thunk$(const google::protobuf::Map<$key_t$, $pkg::Msg$>* m) { delete m; } - void $map_clear_thunk$(google::protobuf::Map<$key_t$, $pkg::Msg$> * m) { m->clear(); } - size_t $map_size_thunk$(const google::protobuf::Map<$key_t$, $pkg::Msg$>* m) { - return m->size(); - } - bool $map_insert_thunk$(google::protobuf::Map<$key_t$, $pkg::Msg$> * m, - $ffi_key_t$ key, $pkg::Msg$ value) { - auto k = $key_expr$; - auto it = m->find(k); - if (it != m->end()) { - return false; - } - (*m)[k] = value; - return true; - } - bool $map_get_thunk$(const google::protobuf::Map<$key_t$, $pkg::Msg$>* m, - $ffi_key_t$ key, const $pkg::Msg$** value) { - auto it = m->find($key_expr$); - if (it == m->end()) { - return false; - } - const $pkg::Msg$* cpp_value = &it->second; - *value = cpp_value; - return true; - } - bool $map_remove_thunk$(google::protobuf::Map<$key_t$, $pkg::Msg$> * m, - $ffi_key_t$ key, $pkg::Msg$ * value) { - auto num_removed = m->erase($key_expr$); - return num_removed > 0; - } - google::protobuf::internal::UntypedMapIterator $map_iter_thunk$( - const google::protobuf::Map<$key_t$, $pkg::Msg$>* m) { - return google::protobuf::internal::UntypedMapIterator::FromTyped(m->cbegin()); - } - void $map_iter_get_thunk$( - const google::protobuf::internal::UntypedMapIterator* iter, - $ffi_key_t$* key, const $pkg::Msg$** value) { - auto typed_iter = iter->ToTyped< - google::protobuf::Map<$key_t$, $pkg::Msg$>::const_iterator>(); - const auto& cpp_key = typed_iter->first; - const auto& cpp_value = typed_iter->second; - *key = $to_ffi_key_expr$; - *value = &cpp_value; - } - } - )cc"); - } } } // namespace rust diff --git a/src/google/protobuf/map.h b/src/google/protobuf/map.h index 0953b12d00..74cfa9d439 100644 --- a/src/google/protobuf/map.h +++ b/src/google/protobuf/map.h @@ -62,6 +62,10 @@ class MapIterator; template struct is_proto_enum; +namespace rust { +struct PtrAndLen; +} // namespace rust + namespace internal { template class MapFieldLite; @@ -590,6 +594,7 @@ class PROTOBUF_EXPORT UntypedMapBase { friend struct MapTestPeer; friend struct MapBenchmarkPeer; friend class UntypedMapIterator; + friend class RustMapHelper; struct NodeAndBucket { NodeBase* node; @@ -924,6 +929,7 @@ class KeyMapBase : public UntypedMapBase { friend class TcParser; friend struct MapTestPeer; friend struct MapBenchmarkPeer; + friend class RustMapHelper; PROTOBUF_NOINLINE void erase_no_destroy(map_index_t b, KeyNode* node) { TreeIterator tree_it; @@ -1139,6 +1145,60 @@ bool InitializeMapKey(T*, K&&, Arena*) { } +// The purpose of this class is to give the Rust implementation visibility into +// some of the internals of C++ proto maps. We need access to these internals +// to be able to implement Rust map operations without duplicating the same +// functionality for every message type. +class RustMapHelper { + public: + using NodeAndBucket = UntypedMapBase::NodeAndBucket; + using ClearInput = UntypedMapBase::ClearInput; + + template + static constexpr MapNodeSizeInfoT SizeInfo() { + return Map::Node::size_info(); + } + + enum { + kKeyIsString = UntypedMapBase::kKeyIsString, + kValueIsProto = UntypedMapBase::kValueIsProto, + }; + + static NodeBase* AllocNode(UntypedMapBase* m, MapNodeSizeInfoT size_info) { + return m->AllocNode(size_info); + } + + static void DeallocNode(UntypedMapBase* m, NodeBase* node, + MapNodeSizeInfoT size_info) { + return m->DeallocNode(node, size_info); + } + + template + static NodeAndBucket FindHelper(Map* m, Key key) { + return m->FindHelper(key); + } + + template + static typename Map::KeyNode* InsertOrReplaceNode(Map* m, NodeBase* node) { + return m->InsertOrReplaceNode(static_cast(node)); + } + + template + static void EraseNoDestroy(Map* m, map_index_t bucket, NodeBase* node) { + m->erase_no_destroy(bucket, static_cast(node)); + } + + static void DestroyMessage(MessageLite* m) { m->DestroyInstance(false); } + + static void ClearTable(UntypedMapBase* m, ClearInput input) { + m->ClearTable(input); + } + + static bool IsGlobalEmptyTable(const UntypedMapBase* m) { + return m->num_buckets_ == kGlobalEmptyTableSize; + } +}; + } // namespace internal // This is the class for Map's internal value_type. @@ -1252,6 +1312,11 @@ class Map : private internal::KeyMapBase> { internal::is_internal_map_value_type>::value, "We only support scalar, Message, and designated internal " "mapped types."); + // The Rust implementation that wraps C++ protos relies on the ability to + // create an UntypedMapBase and cast a pointer of it to google::protobuf::Map*. + static_assert( + sizeof(Map) == sizeof(internal::UntypedMapBase), + "Map must not have any data members beyond what is in UntypedMapBase."); } template @@ -1702,6 +1767,7 @@ class Map : private internal::KeyMapBase> { friend class internal::TcParser; friend struct internal::MapTestPeer; friend struct internal::MapBenchmarkPeer; + friend class internal::RustMapHelper; }; namespace internal { diff --git a/src/google/protobuf/message_lite.h b/src/google/protobuf/message_lite.h index 9e462fa5f9..94e18c1284 100644 --- a/src/google/protobuf/message_lite.h +++ b/src/google/protobuf/message_lite.h @@ -158,6 +158,7 @@ class TcParser; struct TcParseTableBase; class WireFormatLite; class WeakFieldMap; +class RustMapHelper; template class GenericTypeHandler; // defined in repeated_field.h @@ -857,6 +858,7 @@ class PROTOBUF_EXPORT MessageLite { friend class internal::UntypedMapBase; friend class internal::WeakFieldMap; friend class internal::WireFormatLite; + friend class internal::RustMapHelper; template friend class Arena::InternalHelper;