From ccbed29c678be6e69d3e784da9e95ed7d2a6dc64 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Tue, 18 Jun 2024 10:56:44 -0700 Subject: [PATCH] Rust protobuf: make `serialize()` method return `Vec` `Vec` is a more idiomatic Rust type to return for serialization. For the C++ kernel, we are able to return this type with no extra copying. We still use `SerializedData` type for FFI, but convert the result into a `Vec` using a new `into_vec` method. The upb kernel serializes onto an arena, so for upb we do need to copy the data to get it into a `Vec`. PiperOrigin-RevId: 644444571 --- rust/cpp.rs | 22 ++++++++++++++++++++ rust/test/shared/serialization_test.rs | 2 +- rust/upb/wire.rs | 11 +++++----- src/google/protobuf/compiler/rust/message.cc | 8 +++---- 4 files changed, 32 insertions(+), 11 deletions(-) diff --git a/rust/cpp.rs b/rust/cpp.rs index 44451db1cc..4d5760b255 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -164,6 +164,28 @@ impl SerializedData { fn as_mut_ptr(&mut self) -> *mut [u8] { ptr::slice_from_raw_parts_mut(self.data.as_ptr(), self.len) } + + /// Converts into a Vec. + pub fn into_vec(self) -> Vec { + // We need to prevent self from being dropped, because we are going to transfer + // ownership of self.data to the Vec. + let s = std::mem::ManuallyDrop::new(self); + + unsafe { + // SAFETY: + // - `data` was allocated by the Rust global allocator. + // - `data` was allocated with an alignment of 1 for u8. + // - The allocated size was `len`. + // - The length and capacity are equal. + // - All `len` bytes are initialized. + // - The capacity (`len` in this case) is the size the pointer was allocated + // with. + // - The allocated size is no more than isize::MAX, because the protobuf + // serializer will refuse to serialize a message if the output would exceed + // 2^31 - 1 bytes. + Vec::::from_raw_parts(s.data.as_ptr(), s.len, s.len) + } + } } impl Deref for SerializedData { diff --git a/rust/test/shared/serialization_test.rs b/rust/test/shared/serialization_test.rs index ab9dd21ca1..e5b614034f 100644 --- a/rust/test/shared/serialization_test.rs +++ b/rust/test/shared/serialization_test.rs @@ -61,7 +61,7 @@ macro_rules! generate_parameterized_serialization_test { msg.set_optional_bool(true); let mut msg2 = [< $type >]::new(); msg2.set_optional_bytes(msg.serialize().unwrap()); - assert_that!(msg2.optional_bytes(), eq(msg.serialize().unwrap().as_ref())); + assert_that!(msg2.optional_bytes(), eq(msg.serialize().unwrap())); } #[test] diff --git a/rust/upb/wire.rs b/rust/upb/wire.rs index 0b9be44244..7660c84e9c 100644 --- a/rust/upb/wire.rs +++ b/rust/upb/wire.rs @@ -1,4 +1,4 @@ -use crate::{upb_ExtensionRegistry, upb_MiniTable, Arena, OwnedArenaBox, RawArena, RawMessage}; +use crate::{upb_ExtensionRegistry, upb_MiniTable, Arena, RawArena, RawMessage}; use std::ptr::NonNull; // LINT.IfChange(encode_status) @@ -37,12 +37,12 @@ enum DecodeOption { /// If Err, then EncodeStatus != Ok. /// -/// SAFETY: +/// # Safety /// - `msg` must be associated with `mini_table`. pub unsafe fn encode( msg: RawMessage, mini_table: *const upb_MiniTable, -) -> Result, EncodeStatus> { +) -> Result, EncodeStatus> { let arena = Arena::new(); let mut buf: *mut u8 = std::ptr::null_mut(); let mut len = 0usize; @@ -55,8 +55,7 @@ pub unsafe fn encode( if status == EncodeStatus::Ok { assert!(!buf.is_null()); // EncodeStatus Ok should never return NULL data, even for len=0. // SAFETY: upb guarantees that `buf` is valid to read for `len`. - let slice = NonNull::new_unchecked(std::ptr::slice_from_raw_parts_mut(buf, len)); - Ok(OwnedArenaBox::new(slice, arena)) + Ok((*std::ptr::slice_from_raw_parts(buf, len)).to_vec()) } else { Err(status) } @@ -65,7 +64,7 @@ pub unsafe fn encode( /// Decodes into the provided message (merge semantics). If Err, then /// DecodeStatus != Ok. /// -/// SAFETY: +/// # Safety /// - `msg` must be mutable. /// - `msg` must be associated with `mini_table`. pub unsafe fn decode( diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 5fb4d01032..68946b795a 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -66,7 +66,7 @@ void MessageSerialize(Context& ctx, const Descriptor& msg) { $serialize_thunk$(self.raw_msg(), &mut serialized_data) }; if success { - Ok(serialized_data) + Ok(serialized_data.into_vec()) } else { Err($pb$::SerializeError) } @@ -939,7 +939,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { self.msg } - pub fn serialize(&self) -> Result<$pbr$::SerializedData, $pb$::SerializeError> { + pub fn serialize(&self) -> Result, $pb$::SerializeError> { $Msg::serialize$ } @@ -1015,7 +1015,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { self.inner } - pub fn serialize(&self) -> Result<$pbr$::SerializedData, $pb$::SerializeError> { + pub fn serialize(&self) -> Result, $pb$::SerializeError> { $pb$::ViewProxy::as_view(self).serialize() } @@ -1069,7 +1069,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { $raw_arena_getter_for_message$ - pub fn serialize(&self) -> Result<$pbr$::SerializedData, $pb$::SerializeError> { + pub fn serialize(&self) -> Result, $pb$::SerializeError> { self.as_view().serialize() } #[deprecated = "Prefer Msg::parse(), or use the new name 'clear_and_parse' to parse into a pre-existing message."]