From 7f395af40e86e00b892e812beb67a03564884756 Mon Sep 17 00:00:00 2001
From: Adam Cozzette <acozzette@google.com>
Date: Wed, 21 Aug 2024 09:26:18 -0700
Subject: [PATCH] Replace some per-message C++ thunks with a common
 implementation

This change adds delete, clear, serialize, parse, copy_from, and merge_from
operations to the runtime. Since these operations can all be implemented easily
on the `MessageLite` interface, we can use a common implementation in the
runtime instead of generating per-message thunks for all of these.

I suspect this will also make it possible to remove some of our generated trait
implementations and replace them with blanket implementations, but I will leave
that for a future change.

PiperOrigin-RevId: 665910927
---
 rust/cpp.rs                                  |  9 +++
 rust/cpp_kernel/BUILD                        |  1 +
 rust/cpp_kernel/message.cc                   | 36 ++++++++++
 src/google/protobuf/compiler/rust/message.cc | 76 +++++---------------
 4 files changed, 62 insertions(+), 60 deletions(-)
 create mode 100644 rust/cpp_kernel/message.cc

diff --git a/rust/cpp.rs b/rust/cpp.rs
index 3a6fdad664..2c963408a8 100644
--- a/rust/cpp.rs
+++ b/rust/cpp.rs
@@ -95,6 +95,15 @@ pub struct InnerProtoString {
     owned_ptr: CppStdString,
 }
 
+extern "C" {
+    pub fn proto2_rust_Message_delete(m: RawMessage);
+    pub fn proto2_rust_Message_clear(m: RawMessage);
+    pub fn proto2_rust_Message_parse(m: RawMessage, input: SerializedData) -> bool;
+    pub fn proto2_rust_Message_serialize(m: RawMessage, output: &mut SerializedData) -> bool;
+    pub fn proto2_rust_Message_copy_from(dst: RawMessage, src: RawMessage) -> bool;
+    pub fn proto2_rust_Message_merge_from(dst: RawMessage, src: RawMessage) -> bool;
+}
+
 /// An opaque type matching MapNodeSizeInfoT from C++.
 #[doc(hidden)]
 #[repr(transparent)]
diff --git a/rust/cpp_kernel/BUILD b/rust/cpp_kernel/BUILD
index c3fbe90a71..041b0c7a87 100644
--- a/rust/cpp_kernel/BUILD
+++ b/rust/cpp_kernel/BUILD
@@ -8,6 +8,7 @@ cc_library(
         "compare.cc",
         "debug.cc",
         "map.cc",
+        "message.cc",
         "repeated.cc",
         "strings.cc",
     ],
diff --git a/rust/cpp_kernel/message.cc b/rust/cpp_kernel/message.cc
new file mode 100644
index 0000000000..08c5f484d6
--- /dev/null
+++ b/rust/cpp_kernel/message.cc
@@ -0,0 +1,36 @@
+#include <limits>
+
+#include "google/protobuf/message_lite.h"
+#include "rust/cpp_kernel/serialized_data.h"
+
+extern "C" {
+
+void proto2_rust_Message_delete(google::protobuf::MessageLite* m) { delete m; }
+
+void proto2_rust_Message_clear(google::protobuf::MessageLite* m) { m->Clear(); }
+
+bool proto2_rust_Message_parse(google::protobuf::MessageLite* m,
+                               google::protobuf::rust::SerializedData input) {
+  if (input.len > std::numeric_limits<int>::max()) {
+    return false;
+  }
+  return m->ParseFromArray(input.data, static_cast<int>(input.len));
+}
+
+bool proto2_rust_Message_serialize(const google::protobuf::MessageLite* m,
+                                   google::protobuf::rust::SerializedData* output) {
+  return google::protobuf::rust::SerializeMsg(m, output);
+}
+
+void proto2_rust_Message_copy_from(google::protobuf::MessageLite* dst,
+                                   const google::protobuf::MessageLite& src) {
+  dst->Clear();
+  dst->CheckTypeAndMergeFrom(src);
+}
+
+void proto2_rust_Message_merge_from(google::protobuf::MessageLite* dst,
+                                    const google::protobuf::MessageLite& src) {
+  dst->CheckTypeAndMergeFrom(src);
+}
+
+}  // extern "C"
diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc
index d2dcb534ff..ea1fddaa3a 100644
--- a/src/google/protobuf/compiler/rust/message.cc
+++ b/src/google/protobuf/compiler/rust/message.cc
@@ -61,10 +61,10 @@ void MessageNew(Context& ctx, const Descriptor& msg) {
 void MessageSerialize(Context& ctx, const Descriptor& msg) {
   switch (ctx.opts().kernel) {
     case Kernel::kCpp:
-      ctx.Emit({{"serialize_thunk", ThunkName(ctx, msg, "serialize")}}, R"rs(
+      ctx.Emit({}, R"rs(
         let mut serialized_data = $pbr$::SerializedData::new($pbi$::Private);
         let success = unsafe {
-          $serialize_thunk$(self.raw_msg(), &mut serialized_data)
+          $pbr$::proto2_rust_Message_serialize(self.raw_msg(), &mut serialized_data)
         };
         if success {
           Ok(serialized_data.into_vec())
@@ -95,9 +95,9 @@ void MessageSerialize(Context& ctx, const Descriptor& msg) {
 void MessageMutClear(Context& ctx, const Descriptor& msg) {
   switch (ctx.opts().kernel) {
     case Kernel::kCpp:
-      ctx.Emit({{"clear_thunk", ThunkName(ctx, msg, "clear")}},
+      ctx.Emit({},
                R"rs(
-          unsafe { $clear_thunk$(self.raw_msg()) }
+          unsafe { $pbr$::proto2_rust_Message_clear(self.raw_msg()) }
         )rs");
       return;
     case Kernel::kUpb:
@@ -116,11 +116,8 @@ void MessageMutClear(Context& ctx, const Descriptor& msg) {
 void MessageClearAndParse(Context& ctx, const Descriptor& msg) {
   switch (ctx.opts().kernel) {
     case Kernel::kCpp:
-      ctx.Emit(
-          {
-              {"parse_thunk", ThunkName(ctx, msg, "parse")},
-          },
-          R"rs(
+      ctx.Emit({},
+               R"rs(
           let success = unsafe {
             // SAFETY: `data.as_ptr()` is valid to read for `data.len()`.
             let data = $pbr$::SerializedData::from_raw_parts(
@@ -128,7 +125,7 @@ void MessageClearAndParse(Context& ctx, const Descriptor& msg) {
               data.len(),
             );
 
-            $parse_thunk$(self.raw_msg(), data)
+            $pbr$::proto2_rust_Message_parse(self.raw_msg(), data)
           };
           success.then_some(()).ok_or($pb$::ParseError)
         )rs");
@@ -199,12 +196,6 @@ void MessageExterns(Context& ctx, const Descriptor& msg) {
       ctx.Emit(
           {{"new_thunk", ThunkName(ctx, msg, "new")},
            {"placement_new_thunk", ThunkName(ctx, msg, "placement_new")},
-           {"delete_thunk", ThunkName(ctx, msg, "delete")},
-           {"clear_thunk", ThunkName(ctx, msg, "clear")},
-           {"serialize_thunk", ThunkName(ctx, msg, "serialize")},
-           {"parse_thunk", ThunkName(ctx, msg, "parse")},
-           {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")},
-           {"merge_from_thunk", ThunkName(ctx, msg, "merge_from")},
            {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")},
            {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")},
            {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")},
@@ -219,12 +210,6 @@ void MessageExterns(Context& ctx, const Descriptor& msg) {
           R"rs(
           fn $new_thunk$() -> $pbr$::RawMessage;
           fn $placement_new_thunk$(ptr: *mut std::ffi::c_void, m: $pbr$::RawMessage);
-          fn $delete_thunk$(raw_msg: $pbr$::RawMessage);
-          fn $clear_thunk$(raw_msg: $pbr$::RawMessage);
-          fn $serialize_thunk$(raw_msg: $pbr$::RawMessage, out: &mut $pbr$::SerializedData) -> bool;
-          fn $parse_thunk$(raw_msg: $pbr$::RawMessage, data: $pbr$::SerializedData) -> bool;
-          fn $copy_from_thunk$(dst: $pbr$::RawMessage, src: $pbr$::RawMessage);
-          fn $merge_from_thunk$(dst: $pbr$::RawMessage, src: $pbr$::RawMessage);
           fn $repeated_new_thunk$() -> $pbr$::RawRepeatedField;
           fn $repeated_free_thunk$(raw: $pbr$::RawRepeatedField);
           fn $repeated_len_thunk$(raw: $pbr$::RawRepeatedField) -> usize;
@@ -263,19 +248,19 @@ void MessageDrop(Context& ctx, const Descriptor& msg) {
     return;
   }
 
-  ctx.Emit({{"delete_thunk", ThunkName(ctx, msg, "delete")}}, R"rs(
-    unsafe { $delete_thunk$(self.raw_msg()); }
+  ctx.Emit({}, R"rs(
+    unsafe { $pbr$::proto2_rust_Message_delete(self.raw_msg()); }
   )rs");
 }
 
 void IntoProxiedForMessage(Context& ctx, const Descriptor& msg) {
   switch (ctx.opts().kernel) {
     case Kernel::kCpp:
-      ctx.Emit({{"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}}, R"rs(
+      ctx.Emit({}, R"rs(
         impl<'msg> $pb$::IntoProxied<$Msg$> for $Msg$View<'msg> {
           fn into_proxied(self, _private: $pbi$::Private) -> $Msg$ {
             let dst = $Msg$::new();
-            unsafe { $copy_from_thunk$(dst.inner.msg, self.msg) };
+            unsafe { $pbr$::proto2_rust_Message_copy_from(dst.inner.msg, self.msg) };
             dst
           }
         }
@@ -345,16 +330,13 @@ void UpbGeneratedMessageTraitImpls(Context& ctx, const Descriptor& msg) {
 void MessageMutMergeFrom(Context& ctx, const Descriptor& msg) {
   switch (ctx.opts().kernel) {
     case Kernel::kCpp:
-      ctx.Emit(
-          {
-              {"merge_from_thunk", ThunkName(ctx, msg, "merge_from")},
-          },
-          R"rs(
+      ctx.Emit({},
+               R"rs(
           impl $pb$::MergeFrom for $Msg$Mut<'_> {
             fn merge_from(&mut self, src: impl $pb$::AsView<Proxied = $Msg$>) {
               // SAFETY: self and src are both valid `$Msg$`s.
               unsafe {
-                $merge_from_thunk$(self.raw_msg(), src.as_view().raw_msg());
+                $pbr$::proto2_rust_Message_merge_from(self.raw_msg(), src.as_view().raw_msg());
               }
             }
           }
@@ -389,7 +371,6 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) {
       ctx.Emit(
           {
               {"Msg", RsSafeName(msg.name())},
-              {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")},
               {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")},
               {"repeated_get_thunk", ThunkName(ctx, msg, "repeated_get")},
               {"repeated_get_mut_thunk",
@@ -438,7 +419,7 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) {
             // - `i < len(f)` is promised by caller.
             // - `v.raw_msg()` is a valid `const Message&`.
             unsafe {
-              $copy_from_thunk$(
+              $pbr$::proto2_rust_Message_copy_from(
                 $repeated_get_mut_thunk$(f.as_raw($pbi$::Private), i),
                 v.into_proxied($pbi$::Private).raw_msg(),
               );
@@ -467,7 +448,7 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) {
             // - `v.raw_msg()` is a valid `const Message&`.
             unsafe {
               let new_elem = $repeated_add_thunk$(f.as_raw($pbi$::Private));
-              $copy_from_thunk$(new_elem, v.into_proxied($pbi$::Private).raw_msg());
+              $pbr$::proto2_rust_Message_copy_from(new_elem, v.into_proxied($pbi$::Private).raw_msg());
             }
           }
 
@@ -1406,12 +1387,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
        {"QualifiedMsg", cpp::QualifiedClassName(&msg)},
        {"new_thunk", ThunkName(ctx, msg, "new")},
        {"placement_new_thunk", ThunkName(ctx, msg, "placement_new")},
-       {"delete_thunk", ThunkName(ctx, msg, "delete")},
-       {"clear_thunk", ThunkName(ctx, msg, "clear")},
-       {"serialize_thunk", ThunkName(ctx, msg, "serialize")},
-       {"parse_thunk", ThunkName(ctx, msg, "parse")},
-       {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")},
-       {"merge_from_thunk", ThunkName(ctx, msg, "merge_from")},
        {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")},
        {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")},
        {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")},
@@ -1453,25 +1428,6 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
         void $placement_new_thunk$(void* ptr, $QualifiedMsg$& m) {
           new (ptr) $QualifiedMsg$(std::move(m));
         }
-        void $delete_thunk$(void* ptr) { delete static_cast<$QualifiedMsg$*>(ptr); }
-        void $clear_thunk$(void* ptr) {
-          static_cast<$QualifiedMsg$*>(ptr)->Clear();
-        }
-        bool $serialize_thunk$($QualifiedMsg$* msg, google::protobuf::rust::SerializedData* out) {
-          return google::protobuf::rust::SerializeMsg(msg, out);
-        }
-        bool $parse_thunk$($QualifiedMsg$* msg,
-                                 google::protobuf::rust::SerializedData data) {
-          return msg->ParseFromArray(data.data, data.len);
-        }
-
-        void $copy_from_thunk$($QualifiedMsg$* dst, const $QualifiedMsg$* src) {
-          dst->CopyFrom(*src);
-        }
-
-        void $merge_from_thunk$($QualifiedMsg$* dst, const $QualifiedMsg$* src) {
-          dst->MergeFrom(*src);
-        }
 
         void* $repeated_new_thunk$() {
           return new google::protobuf::RepeatedPtrField<$QualifiedMsg$>();