From 7f395af40e86e00b892e812beb67a03564884756 Mon Sep 17 00:00:00 2001 From: Adam Cozzette <acozzette@google.com> Date: Wed, 21 Aug 2024 09:26:18 -0700 Subject: [PATCH] Replace some per-message C++ thunks with a common implementation This change adds delete, clear, serialize, parse, copy_from, and merge_from operations to the runtime. Since these operations can all be implemented easily on the `MessageLite` interface, we can use a common implementation in the runtime instead of generating per-message thunks for all of these. I suspect this will also make it possible to remove some of our generated trait implementations and replace them with blanket implementations, but I will leave that for a future change. PiperOrigin-RevId: 665910927 --- rust/cpp.rs | 9 +++ rust/cpp_kernel/BUILD | 1 + rust/cpp_kernel/message.cc | 36 ++++++++++ src/google/protobuf/compiler/rust/message.cc | 76 +++++--------------- 4 files changed, 62 insertions(+), 60 deletions(-) create mode 100644 rust/cpp_kernel/message.cc diff --git a/rust/cpp.rs b/rust/cpp.rs index 3a6fdad664..2c963408a8 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -95,6 +95,15 @@ pub struct InnerProtoString { owned_ptr: CppStdString, } +extern "C" { + pub fn proto2_rust_Message_delete(m: RawMessage); + pub fn proto2_rust_Message_clear(m: RawMessage); + pub fn proto2_rust_Message_parse(m: RawMessage, input: SerializedData) -> bool; + pub fn proto2_rust_Message_serialize(m: RawMessage, output: &mut SerializedData) -> bool; + pub fn proto2_rust_Message_copy_from(dst: RawMessage, src: RawMessage); + pub fn proto2_rust_Message_merge_from(dst: RawMessage, src: RawMessage); +} + /// An opaque type matching MapNodeSizeInfoT from C++. 
#[doc(hidden)] #[repr(transparent)] diff --git a/rust/cpp_kernel/BUILD b/rust/cpp_kernel/BUILD index c3fbe90a71..041b0c7a87 100644 --- a/rust/cpp_kernel/BUILD +++ b/rust/cpp_kernel/BUILD @@ -8,6 +8,7 @@ cc_library( "compare.cc", "debug.cc", "map.cc", + "message.cc", "repeated.cc", "strings.cc", ], diff --git a/rust/cpp_kernel/message.cc b/rust/cpp_kernel/message.cc new file mode 100644 index 0000000000..08c5f484d6 --- /dev/null +++ b/rust/cpp_kernel/message.cc @@ -0,0 +1,36 @@ +#include <limits> + +#include "google/protobuf/message_lite.h" +#include "rust/cpp_kernel/serialized_data.h" + +extern "C" { + +void proto2_rust_Message_delete(google::protobuf::MessageLite* m) { delete m; } + +void proto2_rust_Message_clear(google::protobuf::MessageLite* m) { m->Clear(); } + +bool proto2_rust_Message_parse(google::protobuf::MessageLite* m, + google::protobuf::rust::SerializedData input) { + if (input.len > std::numeric_limits<int>::max()) { + return false; + } + return m->ParseFromArray(input.data, static_cast<int>(input.len)); +} + +bool proto2_rust_Message_serialize(const google::protobuf::MessageLite* m, + google::protobuf::rust::SerializedData* output) { + return google::protobuf::rust::SerializeMsg(m, output); +} + +void proto2_rust_Message_copy_from(google::protobuf::MessageLite* dst, + const google::protobuf::MessageLite& src) { + dst->Clear(); + dst->CheckTypeAndMergeFrom(src); +} + +void proto2_rust_Message_merge_from(google::protobuf::MessageLite* dst, + const google::protobuf::MessageLite& src) { + dst->CheckTypeAndMergeFrom(src); +} + +} // extern "C" diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index d2dcb534ff..ea1fddaa3a 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -61,10 +61,10 @@ void MessageNew(Context& ctx, const Descriptor& msg) { void MessageSerialize(Context& ctx, const Descriptor& msg) { switch (ctx.opts().kernel) { case 
Kernel::kCpp: - ctx.Emit({{"serialize_thunk", ThunkName(ctx, msg, "serialize")}}, R"rs( + ctx.Emit({}, R"rs( let mut serialized_data = $pbr$::SerializedData::new($pbi$::Private); let success = unsafe { - $serialize_thunk$(self.raw_msg(), &mut serialized_data) + $pbr$::proto2_rust_Message_serialize(self.raw_msg(), &mut serialized_data) }; if success { Ok(serialized_data.into_vec()) @@ -95,9 +95,9 @@ void MessageSerialize(Context& ctx, const Descriptor& msg) { void MessageMutClear(Context& ctx, const Descriptor& msg) { switch (ctx.opts().kernel) { case Kernel::kCpp: - ctx.Emit({{"clear_thunk", ThunkName(ctx, msg, "clear")}}, + ctx.Emit({}, R"rs( - unsafe { $clear_thunk$(self.raw_msg()) } + unsafe { $pbr$::proto2_rust_Message_clear(self.raw_msg()) } )rs"); return; case Kernel::kUpb: @@ -116,11 +116,8 @@ void MessageMutClear(Context& ctx, const Descriptor& msg) { void MessageClearAndParse(Context& ctx, const Descriptor& msg) { switch (ctx.opts().kernel) { case Kernel::kCpp: - ctx.Emit( - { - {"parse_thunk", ThunkName(ctx, msg, "parse")}, - }, - R"rs( + ctx.Emit({}, + R"rs( let success = unsafe { // SAFETY: `data.as_ptr()` is valid to read for `data.len()`. 
let data = $pbr$::SerializedData::from_raw_parts( @@ -128,7 +125,7 @@ void MessageClearAndParse(Context& ctx, const Descriptor& msg) { data.len(), ); - $parse_thunk$(self.raw_msg(), data) + $pbr$::proto2_rust_Message_parse(self.raw_msg(), data) }; success.then_some(()).ok_or($pb$::ParseError) )rs"); @@ -199,12 +196,6 @@ void MessageExterns(Context& ctx, const Descriptor& msg) { ctx.Emit( {{"new_thunk", ThunkName(ctx, msg, "new")}, {"placement_new_thunk", ThunkName(ctx, msg, "placement_new")}, - {"delete_thunk", ThunkName(ctx, msg, "delete")}, - {"clear_thunk", ThunkName(ctx, msg, "clear")}, - {"serialize_thunk", ThunkName(ctx, msg, "serialize")}, - {"parse_thunk", ThunkName(ctx, msg, "parse")}, - {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}, - {"merge_from_thunk", ThunkName(ctx, msg, "merge_from")}, {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")}, {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")}, {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, @@ -219,12 +210,6 @@ void MessageExterns(Context& ctx, const Descriptor& msg) { R"rs( fn $new_thunk$() -> $pbr$::RawMessage; fn $placement_new_thunk$(ptr: *mut std::ffi::c_void, m: $pbr$::RawMessage); - fn $delete_thunk$(raw_msg: $pbr$::RawMessage); - fn $clear_thunk$(raw_msg: $pbr$::RawMessage); - fn $serialize_thunk$(raw_msg: $pbr$::RawMessage, out: &mut $pbr$::SerializedData) -> bool; - fn $parse_thunk$(raw_msg: $pbr$::RawMessage, data: $pbr$::SerializedData) -> bool; - fn $copy_from_thunk$(dst: $pbr$::RawMessage, src: $pbr$::RawMessage); - fn $merge_from_thunk$(dst: $pbr$::RawMessage, src: $pbr$::RawMessage); fn $repeated_new_thunk$() -> $pbr$::RawRepeatedField; fn $repeated_free_thunk$(raw: $pbr$::RawRepeatedField); fn $repeated_len_thunk$(raw: $pbr$::RawRepeatedField) -> usize; @@ -263,19 +248,19 @@ void MessageDrop(Context& ctx, const Descriptor& msg) { return; } - ctx.Emit({{"delete_thunk", ThunkName(ctx, msg, "delete")}}, R"rs( - unsafe { $delete_thunk$(self.raw_msg()); 
} + ctx.Emit({}, R"rs( + unsafe { $pbr$::proto2_rust_Message_delete(self.raw_msg()); } )rs"); } void IntoProxiedForMessage(Context& ctx, const Descriptor& msg) { switch (ctx.opts().kernel) { case Kernel::kCpp: - ctx.Emit({{"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}}, R"rs( + ctx.Emit({}, R"rs( impl<'msg> $pb$::IntoProxied<$Msg$> for $Msg$View<'msg> { fn into_proxied(self, _private: $pbi$::Private) -> $Msg$ { let dst = $Msg$::new(); - unsafe { $copy_from_thunk$(dst.inner.msg, self.msg) }; + unsafe { $pbr$::proto2_rust_Message_copy_from(dst.inner.msg, self.msg) }; dst } } @@ -345,16 +330,13 @@ void UpbGeneratedMessageTraitImpls(Context& ctx, const Descriptor& msg) { void MessageMutMergeFrom(Context& ctx, const Descriptor& msg) { switch (ctx.opts().kernel) { case Kernel::kCpp: - ctx.Emit( - { - {"merge_from_thunk", ThunkName(ctx, msg, "merge_from")}, - }, - R"rs( + ctx.Emit({}, + R"rs( impl $pb$::MergeFrom for $Msg$Mut<'_> { fn merge_from(&mut self, src: impl $pb$::AsView<Proxied = $Msg$>) { // SAFETY: self and src are both valid `$Msg$`s. unsafe { - $merge_from_thunk$(self.raw_msg(), src.as_view().raw_msg()); + $pbr$::proto2_rust_Message_merge_from(self.raw_msg(), src.as_view().raw_msg()); } } } @@ -389,7 +371,6 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) { ctx.Emit( { {"Msg", RsSafeName(msg.name())}, - {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}, {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, {"repeated_get_thunk", ThunkName(ctx, msg, "repeated_get")}, {"repeated_get_mut_thunk", @@ -438,7 +419,7 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) { // - `i < len(f)` is promised by caller. // - `v.raw_msg()` is a valid `const Message&`. 
unsafe { - $copy_from_thunk$( + $pbr$::proto2_rust_Message_copy_from( $repeated_get_mut_thunk$(f.as_raw($pbi$::Private), i), v.into_proxied($pbi$::Private).raw_msg(), ); @@ -467,7 +448,7 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) { // - `v.raw_msg()` is a valid `const Message&`. unsafe { let new_elem = $repeated_add_thunk$(f.as_raw($pbi$::Private)); - $copy_from_thunk$(new_elem, v.into_proxied($pbi$::Private).raw_msg()); + $pbr$::proto2_rust_Message_copy_from(new_elem, v.into_proxied($pbi$::Private).raw_msg()); } } @@ -1406,12 +1387,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { {"QualifiedMsg", cpp::QualifiedClassName(&msg)}, {"new_thunk", ThunkName(ctx, msg, "new")}, {"placement_new_thunk", ThunkName(ctx, msg, "placement_new")}, - {"delete_thunk", ThunkName(ctx, msg, "delete")}, - {"clear_thunk", ThunkName(ctx, msg, "clear")}, - {"serialize_thunk", ThunkName(ctx, msg, "serialize")}, - {"parse_thunk", ThunkName(ctx, msg, "parse")}, - {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}, - {"merge_from_thunk", ThunkName(ctx, msg, "merge_from")}, {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")}, {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")}, {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, @@ -1453,25 +1428,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { void $placement_new_thunk$(void* ptr, $QualifiedMsg$& m) { new (ptr) $QualifiedMsg$(std::move(m)); } - void $delete_thunk$(void* ptr) { delete static_cast<$QualifiedMsg$*>(ptr); } - void $clear_thunk$(void* ptr) { - static_cast<$QualifiedMsg$*>(ptr)->Clear(); - } - bool $serialize_thunk$($QualifiedMsg$* msg, google::protobuf::rust::SerializedData* out) { - return google::protobuf::rust::SerializeMsg(msg, out); - } - bool $parse_thunk$($QualifiedMsg$* msg, - google::protobuf::rust::SerializedData data) { - return msg->ParseFromArray(data.data, data.len); - } - - void $copy_from_thunk$($QualifiedMsg$* dst, const 
$QualifiedMsg$* src) { - dst->CopyFrom(*src); - } - - void $merge_from_thunk$($QualifiedMsg$* dst, const $QualifiedMsg$* src) { - dst->MergeFrom(*src); - } void* $repeated_new_thunk$() { return new google::protobuf::RepeatedPtrField<$QualifiedMsg$>();