Copy repeated string/bytes in upb when pushing/setting/copying

This memory management should be handled by Rust.

I've confirmed this works by running the newly included tests under MSan.
The sanitizer is necessary to detect an incorrect `copy_from` implementation
that results in a use-after-free of memory from the upb arena.

PiperOrigin-RevId: 604689154
pull/15741/head
Alyssa Haroldsen 1 year ago committed by Copybara-Service
parent e8535e70da
commit 3ccccdb855
  1. 61
      rust/test/shared/accessors_repeated_test.rs
  2. 115
      rust/upb.rs

@ -179,11 +179,10 @@ fn test_repeated_message() {
msg2.repeated_nested_message_mut().copy_from(msg.repeated_nested_message()); msg2.repeated_nested_message_mut().copy_from(msg.repeated_nested_message());
assert_that!(msg2.repeated_nested_message().get(0).unwrap().bb(), eq(1)); assert_that!(msg2.repeated_nested_message().get(0).unwrap().bb(), eq(1));
msg2.repeated_nested_message_mut().clear();
assert_that!(msg2.repeated_nested_message().len(), eq(0));
let mut nested2 = NestedMessage::new(); let mut nested2 = NestedMessage::new();
nested2.bb_mut().set(2); nested2.bb_mut().set(2);
// TODO: b/320936046 - Test SettableValue once available
msg.repeated_nested_message_mut().set(0, nested2.as_view()); msg.repeated_nested_message_mut().set(0, nested2.as_view());
assert_that!(msg.repeated_nested_message().get(0).unwrap().bb(), eq(2)); assert_that!(msg.repeated_nested_message().get(0).unwrap().bb(), eq(2));
@ -191,6 +190,12 @@ fn test_repeated_message() {
msg.repeated_nested_message().iter().map(|m| m.bb()).collect::<Vec<_>>(), msg.repeated_nested_message().iter().map(|m| m.bb()).collect::<Vec<_>>(),
eq(vec![2]), eq(vec![2]),
); );
drop(msg);
assert_that!(msg2.repeated_nested_message().get(0).unwrap().bb(), eq(1));
msg2.repeated_nested_message_mut().clear();
assert_that!(msg2.repeated_nested_message().len(), eq(0));
} }
#[test] #[test]
@ -200,17 +205,31 @@ fn test_repeated_strings() {
let mut msg = TestAllTypes::new(); let mut msg = TestAllTypes::new();
assert_that!(msg.repeated_string(), empty()); assert_that!(msg.repeated_string(), empty());
{ {
let s = String::from("set from Mut");
// TODO: b/320936046 - Test SettableValue once available
msg.repeated_string_mut().push(s.as_str().into());
}
msg.repeated_string_mut().push("second str".into());
{
let s2 = String::from("set second str");
// TODO: b/320936046 - Test SettableValue once available // TODO: b/320936046 - Test SettableValue once available
msg.repeated_string_mut().push("set from Mut".into()); msg.repeated_string_mut().set(1, s2.as_str().into());
} }
assert_that!(msg.repeated_string().len(), eq(1)); assert_that!(msg.repeated_string().len(), eq(2));
// TODO: b/320932827 - Use elements_are! when ready
assert_that!(msg.repeated_string().get(0).unwrap(), eq("set from Mut")); assert_that!(msg.repeated_string().get(0).unwrap(), eq("set from Mut"));
assert_that!(msg.repeated_string().get(1).unwrap(), eq("set second str"));
assert_that!(
msg.repeated_string().iter().collect::<Vec<_>>(),
elements_are![eq("set from Mut"), eq("set second str")]
);
older_msg.repeated_string_mut().copy_from(msg.repeated_string()); older_msg.repeated_string_mut().copy_from(msg.repeated_string());
} }
// TODO: b/320932827 - Use elements_are! when ready assert_that!(older_msg.repeated_string().len(), eq(2));
assert_that!(older_msg.repeated_string().len(), eq(1)); assert_that!(
older_msg.repeated_string().iter().collect::<Vec<_>>(),
elements_are![eq("set from Mut"), eq("set second str")]
);
older_msg.repeated_string_mut().clear(); older_msg.repeated_string_mut().clear();
assert_that!(older_msg.repeated_string(), empty()); assert_that!(older_msg.repeated_string(), empty());
@ -223,17 +242,33 @@ fn test_repeated_bytes() {
let mut msg = TestAllTypes::new(); let mut msg = TestAllTypes::new();
assert_that!(msg.repeated_bytes(), empty()); assert_that!(msg.repeated_bytes(), empty());
{ {
let s = Vec::from(b"set from Mut");
// TODO: b/320936046 - Test SettableValue once available
msg.repeated_bytes_mut().push(&s[..]);
}
msg.repeated_bytes_mut().push(b"second bytes");
{
let s2 = Vec::from(b"set second bytes");
// TODO: b/320936046 - Test SettableValue once available // TODO: b/320936046 - Test SettableValue once available
msg.repeated_bytes_mut().push(b"set from Mut"); msg.repeated_bytes_mut().set(1, &s2[..]);
} }
assert_that!(msg.repeated_bytes().len(), eq(1)); assert_that!(msg.repeated_bytes().len(), eq(2));
// TODO: b/320932827 - Use elements_are! when ready
assert_that!(msg.repeated_bytes().get(0).unwrap(), eq(b"set from Mut")); assert_that!(msg.repeated_bytes().get(0).unwrap(), eq(b"set from Mut"));
assert_that!(msg.repeated_bytes().get(1).unwrap(), eq(b"set second bytes"));
assert_that!(
msg.repeated_bytes().iter().collect::<Vec<_>>(),
elements_are![eq(b"set from Mut"), eq(b"set second bytes")]
);
older_msg.repeated_bytes_mut().copy_from(msg.repeated_bytes()); older_msg.repeated_bytes_mut().copy_from(msg.repeated_bytes());
} }
// TODO: b/320932827 - Use elements_are! when ready assert_that!(older_msg.repeated_bytes().len(), eq(2));
assert_that!(older_msg.repeated_bytes().len(), eq(1)); assert_that!(older_msg.repeated_bytes().get(0).unwrap(), eq(b"set from Mut"));
assert_that!(older_msg.repeated_bytes().get(1).unwrap(), eq(b"set second bytes"));
assert_that!(
older_msg.repeated_bytes().iter().collect::<Vec<_>>(),
elements_are![eq(b"set from Mut"), eq(b"set second bytes")]
);
older_msg.repeated_bytes_mut().clear(); older_msg.repeated_bytes_mut().clear();
assert_that!(older_msg.repeated_bytes(), empty()); assert_that!(older_msg.repeated_bytes(), empty());

@ -19,7 +19,7 @@ use std::cell::UnsafeCell;
use std::ffi::c_int; use std::ffi::c_int;
use std::fmt; use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem::{size_of, MaybeUninit}; use std::mem::{size_of, ManuallyDrop, MaybeUninit};
use std::ops::Deref; use std::ops::Deref;
use std::ptr::{self, NonNull}; use std::ptr::{self, NonNull};
use std::slice; use std::slice;
@ -451,15 +451,12 @@ extern "C" {
pub fn upb_Array_GetMutable(arr: RawRepeatedField, i: usize) -> upb_MutableMessageValue; pub fn upb_Array_GetMutable(arr: RawRepeatedField, i: usize) -> upb_MutableMessageValue;
} }
macro_rules! impl_repeated_primitives { macro_rules! impl_repeated_base {
($(($t:ty, $elem_t:ty, $ufield:ident, $upb_tag:expr)),* $(,)?) => { ($t:ty, $elem_t:ty, $ufield:ident, $upb_tag:expr) => {
$(
unsafe impl ProxiedInRepeated for $t {
#[allow(dead_code)] #[allow(dead_code)]
fn repeated_new(_: Private) -> Repeated<$t> { fn repeated_new(_: Private) -> Repeated<$t> {
let arena = Arena::new(); let arena = ManuallyDrop::new(Arena::new());
let raw_arena = arena.raw(); let raw_arena = arena.raw();
std::mem::forget(arena);
unsafe { unsafe {
Repeated::from_inner(InnerRepeatedMut { Repeated::from_inner(InnerRepeatedMut {
raw: upb_Array_New(raw_arena, $upb_tag as c_int), raw: upb_Array_New(raw_arena, $upb_tag as c_int),
@ -482,36 +479,52 @@ macro_rules! impl_repeated_primitives {
unsafe { upb_Array_Size(f.as_raw(Private)) } unsafe { upb_Array_Size(f.as_raw(Private)) }
} }
fn repeated_push(mut f: Mut<Repeated<$t>>, v: View<$t>) { fn repeated_push(mut f: Mut<Repeated<$t>>, v: View<$t>) {
let arena = f.raw_arena(Private);
unsafe { unsafe {
upb_Array_Append( upb_Array_Append(
f.as_raw(Private), f.as_raw(Private),
<$t as UpbTypeConversions>::to_message_value(v), <$t as UpbTypeConversions>::to_message_value_copy_if_required(arena, v),
f.raw_arena(Private)) arena,
)
} }
} }
fn repeated_clear(mut f: Mut<Repeated<$t>>) { fn repeated_clear(mut f: Mut<Repeated<$t>>) {
unsafe { upb_Array_Resize(f.as_raw(Private), 0, f.raw_arena(Private)); } unsafe {
upb_Array_Resize(f.as_raw(Private), 0, f.raw_arena(Private));
}
} }
unsafe fn repeated_get_unchecked(f: View<Repeated<$t>>, i: usize) -> View<$t> { unsafe fn repeated_get_unchecked(f: View<Repeated<$t>>, i: usize) -> View<$t> {
unsafe { unsafe {
<$t as UpbTypeConversions>::from_message_value( <$t as UpbTypeConversions>::from_message_value(upb_Array_Get(f.as_raw(Private), i))
upb_Array_Get(f.as_raw(Private), i)) } }
} }
unsafe fn repeated_set_unchecked(mut f: Mut<Repeated<$t>>, i: usize, v: View<$t>) { unsafe fn repeated_set_unchecked(mut f: Mut<Repeated<$t>>, i: usize, v: View<$t>) {
let arena = f.raw_arena(Private);
unsafe { unsafe {
upb_Array_Set( upb_Array_Set(
f.as_raw(Private), f.as_raw(Private),
i, i,
<$t as UpbTypeConversions>::to_message_value(v.into())) <$t as UpbTypeConversions>::to_message_value_copy_if_required(arena, v.into()),
)
}
} }
};
} }
macro_rules! impl_repeated_primitives {
($(($t:ty, $elem_t:ty, $ufield:ident, $upb_tag:expr)),* $(,)?) => {
$(
unsafe impl ProxiedInRepeated for $t {
impl_repeated_base!($t, $elem_t, $ufield, $upb_tag);
fn repeated_copy_from(src: View<Repeated<$t>>, mut dest: Mut<Repeated<$t>>) { fn repeated_copy_from(src: View<Repeated<$t>>, mut dest: Mut<Repeated<$t>>) {
let arena = dest.raw_arena(Private);
// SAFETY: // SAFETY:
// - `upb_Array_Resize` is unsafe but assumed to be always sound to call. // - `upb_Array_Resize` is unsafe but assumed to be always sound to call.
// - `copy_nonoverlapping` is unsafe but here we guarantee that both pointers // - `copy_nonoverlapping` is unsafe but here we guarantee that both pointers
// are valid, the pointers are `#[repr(u8)]`, and the size is correct. // are valid, the pointers are `#[repr(u8)]`, and the size is correct.
unsafe { unsafe {
if (!upb_Array_Resize(dest.as_raw(Private), src.len(), dest.raw_arena(Private))) { if (!upb_Array_Resize(dest.as_raw(Private), src.len(), arena)) {
panic!("upb_Array_Resize failed."); panic!("upb_Array_Resize failed.");
} }
ptr::copy_nonoverlapping( ptr::copy_nonoverlapping(
@ -525,6 +538,45 @@ macro_rules! impl_repeated_primitives {
} }
} }
// Implements `ProxiedInRepeated` for each `($t, $upb_tag)` pair it is given.
//
// The methods shared with other element types come from `impl_repeated_base!`
// (invoked with `PtrAndLen` / `str_val`); this macro only adds
// `repeated_copy_from`, which copies every element's bytes through
// `copy_bytes_in_arena` instead of copying the `PtrAndLen` values themselves.
macro_rules! impl_repeated_bytes {
($(($t:ty, $upb_tag:expr)),* $(,)?) => {
$(
unsafe impl ProxiedInRepeated for $t {
impl_repeated_base!($t, PtrAndLen, str_val, $upb_tag);
// Resizes `dest` to `src.len()` elements and overwrites each element with a
// copy of the corresponding `src` element. Panics with
// "upb_Array_Resize failed." if the resize reports failure.
fn repeated_copy_from(src: View<Repeated<$t>>, mut dest: Mut<Repeated<$t>>) {
let len = src.len();
// SAFETY:
// - `upb_Array_Resize` is unsafe but assumed to be always sound to call.
// - `upb_Array` ensures its elements are never uninitialized memory.
// - The `DataPtr` and `MutableDataPtr` functions return pointers to spans
// of memory that are valid for at least `len` elements of PtrAndLen.
// - `Arena::from_raw` receives `dest.raw_arena(Private)`; the resulting
// `Arena` is wrapped in `ManuallyDrop`, so its destructor never runs here.
// - The bytes held within a valid array are valid.
unsafe {
// Borrow the destination's arena for the copies below without dropping it
// when `arena` goes out of scope.
let arena = ManuallyDrop::new(Arena::from_raw(dest.raw_arena(Private)));
if (!upb_Array_Resize(dest.as_raw(Private), src.len(), arena.raw())) {
panic!("upb_Array_Resize failed.");
}
// Both slices are created after the resize and span `len` elements each.
let src_ptrs: &[PtrAndLen] = slice::from_raw_parts(
upb_Array_DataPtr(src.as_raw(Private)).cast(),
len
);
let dest_ptrs: &mut [PtrAndLen] = slice::from_raw_parts_mut(
upb_Array_MutableDataPtr(dest.as_raw(Private)).cast(),
len
);
// Store a pointer/length for a fresh copy of each source element.
// Presumably this keeps `dest` from referencing memory owned by `src`'s
// arena - TODO confirm against `copy_bytes_in_arena`.
for (src_ptr, dest_ptr) in src_ptrs.iter().zip(dest_ptrs) {
*dest_ptr = copy_bytes_in_arena(&arena, src_ptr.as_ref()).into();
}
}
}
}
)*
}
}
impl<'msg, T: ?Sized> RepeatedMut<'msg, T> { impl<'msg, T: ?Sized> RepeatedMut<'msg, T> {
// Returns a `RawArena` which is live for at least `'msg` // Returns a `RawArena` which is live for at least `'msg`
#[doc(hidden)] #[doc(hidden)]
@ -542,10 +594,10 @@ impl_repeated_primitives!(
(u32, u32, uint32_val, UpbCType::UInt32), (u32, u32, uint32_val, UpbCType::UInt32),
(i64, i64, int64_val, UpbCType::Int64), (i64, i64, int64_val, UpbCType::Int64),
(u64, u64, uint64_val, UpbCType::UInt64), (u64, u64, uint64_val, UpbCType::UInt64),
(ProtoStr, PtrAndLen, str_val, UpbCType::String),
([u8], PtrAndLen, str_val, UpbCType::Bytes),
); );
impl_repeated_bytes!((ProtoStr, UpbCType::String), ([u8], UpbCType::Bytes),);
/// Copy the contents of `src` into `dest`. /// Copy the contents of `src` into `dest`.
/// ///
/// # Safety /// # Safety
@ -692,22 +744,27 @@ macro_rules! impl_upb_type_conversions_for_scalars {
($($t:ty, $ufield:ident, $upb_tag:expr, $zero_val:literal;)*) => { ($($t:ty, $ufield:ident, $upb_tag:expr, $zero_val:literal;)*) => {
$( $(
impl UpbTypeConversions for $t { impl UpbTypeConversions for $t {
#[inline(always)]
fn upb_type() -> UpbCType { fn upb_type() -> UpbCType {
$upb_tag $upb_tag
} }
#[inline(always)]
fn to_message_value(val: View<'_, $t>) -> upb_MessageValue { fn to_message_value(val: View<'_, $t>) -> upb_MessageValue {
upb_MessageValue { $ufield: val } upb_MessageValue { $ufield: val }
} }
#[inline(always)]
fn empty_message_value() -> upb_MessageValue { fn empty_message_value() -> upb_MessageValue {
Self::to_message_value($zero_val) Self::to_message_value($zero_val)
} }
#[inline(always)]
unsafe fn to_message_value_copy_if_required(_: RawArena, val: View<'_, $t>) -> upb_MessageValue { unsafe fn to_message_value_copy_if_required(_: RawArena, val: View<'_, $t>) -> upb_MessageValue {
Self::to_message_value(val) Self::to_message_value(val)
} }
#[inline(always)]
unsafe fn from_message_value<'msg>(msg: upb_MessageValue) -> View<'msg, $t> { unsafe fn from_message_value<'msg>(msg: upb_MessageValue) -> View<'msg, $t> {
unsafe { msg.$ufield } unsafe { msg.$ufield }
} }
@ -743,13 +800,10 @@ impl UpbTypeConversions for [u8] {
raw_arena: RawArena, raw_arena: RawArena,
val: View<'_, [u8]>, val: View<'_, [u8]>,
) -> upb_MessageValue { ) -> upb_MessageValue {
// SAFETY: // SAFETY: The arena memory is not freed due to `ManuallyDrop`.
// The arena memory is not freed because we prevent its destructor from let arena = ManuallyDrop::new(unsafe { Arena::from_raw(raw_arena) });
// executing with the call to `std::mem::forget(arena)`.
let arena = unsafe { Arena::from_raw(raw_arena) };
let copied = copy_bytes_in_arena(&arena, val); let copied = copy_bytes_in_arena(&arena, val);
let msg_val = Self::to_message_value(copied); let msg_val = Self::to_message_value(copied);
std::mem::forget(arena);
msg_val msg_val
} }
@ -775,18 +829,13 @@ impl UpbTypeConversions for ProtoStr {
raw_arena: RawArena, raw_arena: RawArena,
val: View<'_, ProtoStr>, val: View<'_, ProtoStr>,
) -> upb_MessageValue { ) -> upb_MessageValue {
// SAFETY: // SAFETY: `raw_arena` is valid as promised by the caller
// The arena memory is not freed because we prevent its destructor from unsafe {
// executing with the call to `std::mem::forget(arena)`. <[u8] as UpbTypeConversions>::to_message_value_copy_if_required(
let arena = unsafe { Arena::from_raw(raw_arena) }; raw_arena,
let copied = copy_bytes_in_arena(&arena, val.into()); val.as_bytes(),
)
// SAFETY: }
// `val` is a valid `ProtoStr` and `copied` is an exact copy of `val`.
let proto_str = unsafe { ProtoStr::from_utf8_unchecked(copied) };
let msg_val = Self::to_message_value(proto_str);
std::mem::forget(arena);
msg_val
} }
unsafe fn from_message_value<'msg>(msg: upb_MessageValue) -> View<'msg, ProtoStr> { unsafe fn from_message_value<'msg>(msg: upb_MessageValue) -> View<'msg, ProtoStr> {

Loading…
Cancel
Save