diff --git a/rust/cpp.rs b/rust/cpp.rs index 2f8fb90e86..661b3c5613 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -812,7 +812,6 @@ macro_rules! impl_map_primitives { size_info: MapNodeSizeInfo, key: $cpp_type, value: RawMessage, - placement_new: unsafe extern "C" fn(*mut c_void, m: RawMessage), ) -> bool; pub fn $get_thunk( m: RawMap, diff --git a/rust/cpp_kernel/map.cc b/rust/cpp_kernel/map.cc index 2610836a28..45b996b92d 100644 --- a/rust/cpp_kernel/map.cc +++ b/rust/cpp_kernel/map.cc @@ -7,6 +7,7 @@ #include #include "google/protobuf/map.h" +#include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "rust/cpp_kernel/strings.h" @@ -42,16 +43,27 @@ void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node, template bool Insert(internal::UntypedMapBase* m, internal::MapNodeSizeInfoT size_info, - Key key, MessageLite* value, - void (*placement_new)(void*, MessageLite*)) { + Key key, MessageLite* value) { internal::NodeBase* node = internal::RustMapHelper::AllocNode(m, size_info); if constexpr (std::is_same::value) { new (node->GetVoidKey()) std::string(key.ptr, key.len); } else { *static_cast(node->GetVoidKey()) = key; } - void* value_ptr = node->GetVoidValue(size_info); - placement_new(value_ptr, value); + + MessageLite* new_msg = internal::RustMapHelper::PlacementNew( + value, node->GetVoidValue(size_info)); + auto* full_msg = DynamicCastMessage(new_msg); + + // If we are working with a full (non-lite) proto, we reflectively swap the + // value into place. Otherwise, we have to perform a copy. + if (full_msg != nullptr) { + full_msg->GetReflection()->Swap(full_msg, + DynamicCastMessage(value)); + } else { + new_msg->CheckTypeAndMergeFrom(*value); + } + node = internal::RustMapHelper::InsertOrReplaceNode( static_cast*>(m), node); if (node == nullptr) { @@ -166,33 +178,32 @@ google::protobuf::internal::UntypedMapIterator proto2_rust_map_iter( return m->begin(); } -#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ - bool proto2_rust_map_insert_##suffix( \ - google::protobuf::internal::UntypedMapBase* m, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \ - google::protobuf::MessageLite* value, \ - void (*placement_new)(void*, google::protobuf::MessageLite*)) { \ - return google::protobuf::rust::Insert(m, size_info, key, value, placement_new); \ - } \ - \ - bool proto2_rust_map_get_##suffix( \ - google::protobuf::internal::UntypedMapBase* m, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \ - google::protobuf::MessageLite** value) { \ - return google::protobuf::rust::Get(m, size_info, key, value); \ - } \ - \ - bool proto2_rust_map_remove_##suffix( \ - google::protobuf::internal::UntypedMapBase* m, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key) { \ - return google::protobuf::rust::Remove(m, size_info, key); \ - } \ - \ - void proto2_rust_map_iter_get_##suffix( \ - const google::protobuf::internal::UntypedMapIterator* iter, \ - google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type* key, \ - google::protobuf::MessageLite** value) { \ - return google::protobuf::rust::IterGet(iter, size_info, key, value); \ +#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ + bool proto2_rust_map_insert_##suffix( \ + google::protobuf::internal::UntypedMapBase* m, \ + google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \ + google::protobuf::MessageLite* value) { \ + return google::protobuf::rust::Insert(m, size_info, key, value); \ + } \ + \ + bool proto2_rust_map_get_##suffix( \ + google::protobuf::internal::UntypedMapBase* m, \ + google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key, \ + google::protobuf::MessageLite** value) { \ + return google::protobuf::rust::Get(m, size_info, key, value); \ + } \ + \ + bool proto2_rust_map_remove_##suffix( \ + google::protobuf::internal::UntypedMapBase* m, \ + google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type key) { \ + return google::protobuf::rust::Remove(m, size_info, key); \ + } \ + \ + void proto2_rust_map_iter_get_##suffix( \ + const google::protobuf::internal::UntypedMapIterator* iter, \ + google::protobuf::internal::MapNodeSizeInfoT size_info, cpp_type* key, \ + google::protobuf::MessageLite** value) { \ + return google::protobuf::rust::IterGet(iter, size_info, key, value); \ } DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int32_t, i32) diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 91b0a45e4c..6bdde9a95c 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -195,7 +195,6 @@ void CppMessageExterns(Context& ctx, const Descriptor& msg) { ABSL_CHECK(ctx.is_cpp()); ctx.Emit( {{"new_thunk", ThunkName(ctx, msg, "new")}, - {"placement_new_thunk", ThunkName(ctx, msg, "placement_new")}, {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")}, {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")}, {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, @@ -208,7 +207,6 @@ void CppMessageExterns(Context& ctx, const Descriptor& msg) { {"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}}, R"rs( fn $new_thunk$() -> $pbr$::RawMessage; - fn $placement_new_thunk$(ptr: *mut $std$::ffi::c_void, m: $pbr$::RawMessage); fn $repeated_new_thunk$() -> $pbr$::RawRepeatedField; fn $repeated_free_thunk$(raw: $pbr$::RawRepeatedField); fn $repeated_len_thunk$(raw: $pbr$::RawRepeatedField) -> usize; @@ -566,7 +564,6 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { for (const auto& t : kMapKeyTypes) { ctx.Emit( {{"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}, - {"placement_new_thunk", ThunkName(ctx, msg, "placement_new")}, {"map_insert", absl::StrCat("proto2_rust_map_insert_", t.thunk_ident)}, {"map_remove", @@ -616,7 +613,7 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { map.as_raw($pbi$::Private), $map_size_info_thunk$($key_t$::SIZE_INFO_INDEX), $key_expr$, - value.into_proxied($pbi$::Private).raw_msg(), $placement_new_thunk$) + value.into_proxied($pbi$::Private).raw_msg()) } } @@ -1355,7 +1352,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { {"Msg", RsSafeName(msg.name())}, {"QualifiedMsg", cpp::QualifiedClassName(&msg)}, {"new_thunk", ThunkName(ctx, msg, "new")}, - {"placement_new_thunk", ThunkName(ctx, msg, "placement_new")}, {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")}, {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")}, {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, @@ -1391,9 +1387,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { // clang-format off extern $abi$ { void* $new_thunk$() { return new $QualifiedMsg$(); } - void $placement_new_thunk$(void* ptr, $QualifiedMsg$& m) { - new (ptr) $QualifiedMsg$(std::move(m)); - } void* $repeated_new_thunk$() { return new google::protobuf::RepeatedPtrField<$QualifiedMsg$>(); diff --git a/src/google/protobuf/map.h b/src/google/protobuf/map.h index 4e44456273..a65f8f2d63 100644 --- a/src/google/protobuf/map.h +++ b/src/google/protobuf/map.h @@ -1190,6 +1190,11 @@ class RustMapHelper { m->erase_no_destroy(bucket, static_cast(node)); } + static google::protobuf::MessageLite* PlacementNew(const MessageLite* prototype, + void* mem) { + return prototype->GetClassData()->PlacementNew(mem, /* arena = */ nullptr); + } + static void DestroyMessage(MessageLite* m) { m->DestroyInstance(); } static void ClearTable(UntypedMapBase* m, ClearInput input) {