Copy repeated string/bytes in upb when pushing/setting/copying

This memory management should be handled by Rust.

I've confirmed this works by running the newly included tests under MSan.
The sanitizer is necessary to detect an incorrect `copy_from` implementation
that reads freed memory (a use-after-free) from the upb arena.

PiperOrigin-RevId: 604689154
pull/15741/head
Alyssa Haroldsen 1 year ago committed by Copybara-Service
parent e8535e70da
commit 3ccccdb855
  1. 61
      rust/test/shared/accessors_repeated_test.rs
  2. 193
      rust/upb.rs

@ -179,11 +179,10 @@ fn test_repeated_message() {
msg2.repeated_nested_message_mut().copy_from(msg.repeated_nested_message());
assert_that!(msg2.repeated_nested_message().get(0).unwrap().bb(), eq(1));
msg2.repeated_nested_message_mut().clear();
assert_that!(msg2.repeated_nested_message().len(), eq(0));
let mut nested2 = NestedMessage::new();
nested2.bb_mut().set(2);
// TODO: b/320936046 - Test SettableValue once available
msg.repeated_nested_message_mut().set(0, nested2.as_view());
assert_that!(msg.repeated_nested_message().get(0).unwrap().bb(), eq(2));
@ -191,6 +190,12 @@ fn test_repeated_message() {
msg.repeated_nested_message().iter().map(|m| m.bb()).collect::<Vec<_>>(),
eq(vec![2]),
);
drop(msg);
assert_that!(msg2.repeated_nested_message().get(0).unwrap().bb(), eq(1));
msg2.repeated_nested_message_mut().clear();
assert_that!(msg2.repeated_nested_message().len(), eq(0));
}
#[test]
@ -200,17 +205,31 @@ fn test_repeated_strings() {
let mut msg = TestAllTypes::new();
assert_that!(msg.repeated_string(), empty());
{
let s = String::from("set from Mut");
// TODO: b/320936046 - Test SettableValue once available
msg.repeated_string_mut().push("set from Mut".into());
msg.repeated_string_mut().push(s.as_str().into());
}
assert_that!(msg.repeated_string().len(), eq(1));
// TODO: b/320932827 - Use elements_are! when ready
msg.repeated_string_mut().push("second str".into());
{
let s2 = String::from("set second str");
// TODO: b/320936046 - Test SettableValue once available
msg.repeated_string_mut().set(1, s2.as_str().into());
}
assert_that!(msg.repeated_string().len(), eq(2));
assert_that!(msg.repeated_string().get(0).unwrap(), eq("set from Mut"));
assert_that!(msg.repeated_string().get(1).unwrap(), eq("set second str"));
assert_that!(
msg.repeated_string().iter().collect::<Vec<_>>(),
elements_are![eq("set from Mut"), eq("set second str")]
);
older_msg.repeated_string_mut().copy_from(msg.repeated_string());
}
// TODO: b/320932827 - Use elements_are! when ready
assert_that!(older_msg.repeated_string().len(), eq(1));
assert_that!(older_msg.repeated_string().len(), eq(2));
assert_that!(
older_msg.repeated_string().iter().collect::<Vec<_>>(),
elements_are![eq("set from Mut"), eq("set second str")]
);
older_msg.repeated_string_mut().clear();
assert_that!(older_msg.repeated_string(), empty());
@ -223,17 +242,33 @@ fn test_repeated_bytes() {
let mut msg = TestAllTypes::new();
assert_that!(msg.repeated_bytes(), empty());
{
let s = Vec::from(b"set from Mut");
// TODO: b/320936046 - Test SettableValue once available
msg.repeated_bytes_mut().push(b"set from Mut");
msg.repeated_bytes_mut().push(&s[..]);
}
assert_that!(msg.repeated_bytes().len(), eq(1));
// TODO: b/320932827 - Use elements_are! when ready
msg.repeated_bytes_mut().push(b"second bytes");
{
let s2 = Vec::from(b"set second bytes");
// TODO: b/320936046 - Test SettableValue once available
msg.repeated_bytes_mut().set(1, &s2[..]);
}
assert_that!(msg.repeated_bytes().len(), eq(2));
assert_that!(msg.repeated_bytes().get(0).unwrap(), eq(b"set from Mut"));
assert_that!(msg.repeated_bytes().get(1).unwrap(), eq(b"set second bytes"));
assert_that!(
msg.repeated_bytes().iter().collect::<Vec<_>>(),
elements_are![eq(b"set from Mut"), eq(b"set second bytes")]
);
older_msg.repeated_bytes_mut().copy_from(msg.repeated_bytes());
}
// TODO: b/320932827 - Use elements_are! when ready
assert_that!(older_msg.repeated_bytes().len(), eq(1));
assert_that!(older_msg.repeated_bytes().len(), eq(2));
assert_that!(older_msg.repeated_bytes().get(0).unwrap(), eq(b"set from Mut"));
assert_that!(older_msg.repeated_bytes().get(1).unwrap(), eq(b"set second bytes"));
assert_that!(
older_msg.repeated_bytes().iter().collect::<Vec<_>>(),
elements_are![eq(b"set from Mut"), eq(b"set second bytes")]
);
older_msg.repeated_bytes_mut().clear();
assert_that!(older_msg.repeated_bytes(), empty());

@ -19,7 +19,7 @@ use std::cell::UnsafeCell;
use std::ffi::c_int;
use std::fmt;
use std::marker::PhantomData;
use std::mem::{size_of, MaybeUninit};
use std::mem::{size_of, ManuallyDrop, MaybeUninit};
use std::ops::Deref;
use std::ptr::{self, NonNull};
use std::slice;
@ -451,67 +451,80 @@ extern "C" {
pub fn upb_Array_GetMutable(arr: RawRepeatedField, i: usize) -> upb_MutableMessageValue;
}
// Expands to the `ProxiedInRepeated` methods that do not depend on how an
// element is bulk-copied: new / free / len / push / clear / get / set.
// It is meant to be invoked inside an `unsafe impl ProxiedInRepeated for $t`
// block, which then supplies its own `repeated_copy_from`.
//
// Parameters:
// - `$t`: the proxied element type the methods are generated for.
// - `$elem_t`, `$ufield`: accepted but not referenced by this macro body.
// - `$upb_tag`: the `UpbCType` tag, cast to `c_int` and passed to
//   `upb_Array_New`.
macro_rules! impl_repeated_base {
($t:ty, $elem_t:ty, $ufield:ident, $upb_tag:expr) => {
// Creates an empty repeated field backed by a freshly created arena.
#[allow(dead_code)]
fn repeated_new(_: Private) -> Repeated<$t> {
// `ManuallyDrop` keeps the `Arena` destructor from running at the end of
// this function; only the raw pointer is stored in the returned value and
// it is released later by `repeated_free`.
let arena = ManuallyDrop::new(Arena::new());
let raw_arena = arena.raw();
unsafe {
Repeated::from_inner(InnerRepeatedMut {
raw: upb_Array_New(raw_arena, $upb_tag as c_int),
arena: raw_arena,
_phantom: PhantomData,
})
}
}
// Releases the arena stored in `f`.
#[allow(dead_code)]
unsafe fn repeated_free(_: Private, f: &mut Repeated<$t>) {
// The array was created from this arena (see `repeated_new`), so freeing
// the arena is the only cleanup performed here.
// SAFETY:
// - `f.inner().arena` is expected to still be a live `upb_Arena*`.
//   NOTE(review): the original justification was left unfinished — confirm
//   the exact invariant against the `Repeated` drop path.
// - This function is only called once for `f`
unsafe {
upb_Arena_Free(f.inner().arena);
}
}
// Returns the number of elements currently in the array.
fn repeated_len(f: View<Repeated<$t>>) -> usize {
unsafe { upb_Array_Size(f.as_raw(Private)) }
}
// Appends `v` to the array. The value is converted with
// `to_message_value_copy_if_required`, passing the array's own arena so the
// conversion can copy borrowed data into it when the element type needs that.
fn repeated_push(mut f: Mut<Repeated<$t>>, v: View<$t>) {
let arena = f.raw_arena(Private);
unsafe {
upb_Array_Append(
f.as_raw(Private),
<$t as UpbTypeConversions>::to_message_value_copy_if_required(arena, v),
arena,
)
}
}
// Resizes the array to zero elements; the value returned by
// `upb_Array_Resize` is discarded.
fn repeated_clear(mut f: Mut<Repeated<$t>>) {
unsafe {
upb_Array_Resize(f.as_raw(Private), 0, f.raw_arena(Private));
}
}
// Reads element `i` without a bounds check in this function.
// Presumably the caller must guarantee `i < repeated_len(f)` — confirm
// against the trait's documented contract.
unsafe fn repeated_get_unchecked(f: View<Repeated<$t>>, i: usize) -> View<$t> {
unsafe {
<$t as UpbTypeConversions>::from_message_value(upb_Array_Get(f.as_raw(Private), i))
}
}
// Overwrites element `i` with `v` without a bounds check in this function.
// As in `repeated_push`, the conversion receives the array's arena so that
// the stored value does not have to borrow from `v`.
// Presumably the caller must guarantee `i < repeated_len(f)` — confirm
// against the trait's documented contract.
unsafe fn repeated_set_unchecked(mut f: Mut<Repeated<$t>>, i: usize, v: View<$t>) {
let arena = f.raw_arena(Private);
unsafe {
upb_Array_Set(
f.as_raw(Private),
i,
<$t as UpbTypeConversions>::to_message_value_copy_if_required(arena, v.into()),
)
}
}
};
}
macro_rules! impl_repeated_primitives {
($(($t:ty, $elem_t:ty, $ufield:ident, $upb_tag:expr)),* $(,)?) => {
$(
unsafe impl ProxiedInRepeated for $t {
#[allow(dead_code)]
fn repeated_new(_: Private) -> Repeated<$t> {
let arena = Arena::new();
let raw_arena = arena.raw();
std::mem::forget(arena);
unsafe {
Repeated::from_inner(InnerRepeatedMut {
raw: upb_Array_New(raw_arena, $upb_tag as c_int),
arena: raw_arena,
_phantom: PhantomData,
})
}
}
#[allow(dead_code)]
unsafe fn repeated_free(_: Private, f: &mut Repeated<$t>) {
// Freeing the array itself is handled by `Arena::Drop`
// SAFETY:
// - `f.raw_arena()` is a live `upb_Arena*` as
// - This function is only called once for `f`
unsafe {
upb_Arena_Free(f.inner().arena);
}
}
fn repeated_len(f: View<Repeated<$t>>) -> usize {
unsafe { upb_Array_Size(f.as_raw(Private)) }
}
fn repeated_push(mut f: Mut<Repeated<$t>>, v: View<$t>) {
unsafe {
upb_Array_Append(
f.as_raw(Private),
<$t as UpbTypeConversions>::to_message_value(v),
f.raw_arena(Private))
}
}
fn repeated_clear(mut f: Mut<Repeated<$t>>) {
unsafe { upb_Array_Resize(f.as_raw(Private), 0, f.raw_arena(Private)); }
}
unsafe fn repeated_get_unchecked(f: View<Repeated<$t>>, i: usize) -> View<$t> {
unsafe {
<$t as UpbTypeConversions>::from_message_value(
upb_Array_Get(f.as_raw(Private), i)) }
}
unsafe fn repeated_set_unchecked(mut f: Mut<Repeated<$t>>, i: usize, v: View<$t>) {
unsafe {
upb_Array_Set(
f.as_raw(Private),
i,
<$t as UpbTypeConversions>::to_message_value(v.into()))
}
}
impl_repeated_base!($t, $elem_t, $ufield, $upb_tag);
fn repeated_copy_from(src: View<Repeated<$t>>, mut dest: Mut<Repeated<$t>>) {
let arena = dest.raw_arena(Private);
// SAFETY:
// - `upb_Array_Resize` is unsafe but assumed to be always sound to call.
// - `copy_nonoverlapping` is unsafe but here we guarantee that both pointers
// are valid, the pointers are `#[repr(u8)]`, and the size is correct.
unsafe {
if (!upb_Array_Resize(dest.as_raw(Private), src.len(), dest.raw_arena(Private))) {
if (!upb_Array_Resize(dest.as_raw(Private), src.len(), arena)) {
panic!("upb_Array_Resize failed.");
}
ptr::copy_nonoverlapping(
@ -525,6 +538,45 @@ macro_rules! impl_repeated_primitives {
}
}
// Implements `ProxiedInRepeated` for element types stored in the array as
// `PtrAndLen` (pointer + length) values. The shared methods come from
// `impl_repeated_base!`; this macro adds a `repeated_copy_from` that copies
// every element's bytes into the destination arena, so that `dest` never
// points into memory owned by `src`.
//
// Each entry is `(element type, UpbCType tag)`.
macro_rules! impl_repeated_bytes {
    ($(($t:ty, $upb_tag:expr)),* $(,)?) => {
        $(
            unsafe impl ProxiedInRepeated for $t {
                impl_repeated_base!($t, PtrAndLen, str_val, $upb_tag);

                // Replaces the contents of `dest` with deep copies of the elements
                // of `src`.
                //
                // Panics if `upb_Array_Resize` reports failure.
                fn repeated_copy_from(src: View<Repeated<$t>>, mut dest: Mut<Repeated<$t>>) {
                    let len = src.len();
                    // SAFETY:
                    // - `upb_Array_Resize` is unsafe but assumed to be always sound to call.
                    // - `upb_Array` ensures its elements are never uninitialized memory.
                    // - The `DataPtr` and `MutableDataPtr` functions return pointers to spans
                    //   of memory that are valid for at least `len` elements of PtrAndLen.
                    // - The slices are only built when `len > 0`, so a data pointer that is
                    //   null for an empty array is never passed to `slice::from_raw_parts`.
                    // - The bytes held within a valid array are valid.
                    unsafe {
                        // `ManuallyDrop` keeps this temporary `Arena` handle from freeing
                        // the arena that `dest` still uses.
                        let arena = ManuallyDrop::new(Arena::from_raw(dest.raw_arena(Private)));
                        if !upb_Array_Resize(dest.as_raw(Private), len, arena.raw()) {
                            panic!("upb_Array_Resize failed.");
                        }
                        if len == 0 {
                            // Nothing to copy; `dest` has already been resized to empty.
                            return;
                        }
                        let src_ptrs: &[PtrAndLen] = slice::from_raw_parts(
                            upb_Array_DataPtr(src.as_raw(Private)).cast(),
                            len,
                        );
                        let dest_ptrs: &mut [PtrAndLen] = slice::from_raw_parts_mut(
                            upb_Array_MutableDataPtr(dest.as_raw(Private)).cast(),
                            len,
                        );
                        // Copy the bytes, not just the pointers.
                        for (src_ptr, dest_ptr) in src_ptrs.iter().zip(dest_ptrs) {
                            *dest_ptr = copy_bytes_in_arena(&arena, src_ptr.as_ref()).into();
                        }
                    }
                }
            }
        )*
    }
}
impl<'msg, T: ?Sized> RepeatedMut<'msg, T> {
// Returns a `RawArena` which is live for at least `'msg`
#[doc(hidden)]
@ -542,10 +594,10 @@ impl_repeated_primitives!(
(u32, u32, uint32_val, UpbCType::UInt32),
(i64, i64, int64_val, UpbCType::Int64),
(u64, u64, uint64_val, UpbCType::UInt64),
(ProtoStr, PtrAndLen, str_val, UpbCType::String),
([u8], PtrAndLen, str_val, UpbCType::Bytes),
);
impl_repeated_bytes!((ProtoStr, UpbCType::String), ([u8], UpbCType::Bytes),);
/// Copy the contents of `src` into `dest`.
///
/// # Safety
@ -692,22 +744,27 @@ macro_rules! impl_upb_type_conversions_for_scalars {
($($t:ty, $ufield:ident, $upb_tag:expr, $zero_val:literal;)*) => {
$(
impl UpbTypeConversions for $t {
#[inline(always)]
fn upb_type() -> UpbCType {
$upb_tag
}
#[inline(always)]
fn to_message_value(val: View<'_, $t>) -> upb_MessageValue {
upb_MessageValue { $ufield: val }
}
#[inline(always)]
fn empty_message_value() -> upb_MessageValue {
Self::to_message_value($zero_val)
}
unsafe fn to_message_value_copy_if_required(_ : RawArena, val: View<'_, $t>) -> upb_MessageValue {
#[inline(always)]
unsafe fn to_message_value_copy_if_required(_: RawArena, val: View<'_, $t>) -> upb_MessageValue {
Self::to_message_value(val)
}
#[inline(always)]
unsafe fn from_message_value<'msg>(msg: upb_MessageValue) -> View<'msg, $t> {
unsafe { msg.$ufield }
}
@ -743,13 +800,10 @@ impl UpbTypeConversions for [u8] {
raw_arena: RawArena,
val: View<'_, [u8]>,
) -> upb_MessageValue {
// SAFETY:
// The arena memory is not freed because we prevent its destructor from
// executing with the call to `std::mem::forget(arena)`.
let arena = unsafe { Arena::from_raw(raw_arena) };
// SAFETY: The arena memory is not freed due to `ManuallyDrop`.
let arena = ManuallyDrop::new(unsafe { Arena::from_raw(raw_arena) });
let copied = copy_bytes_in_arena(&arena, val);
let msg_val = Self::to_message_value(copied);
std::mem::forget(arena);
msg_val
}
@ -775,18 +829,13 @@ impl UpbTypeConversions for ProtoStr {
raw_arena: RawArena,
val: View<'_, ProtoStr>,
) -> upb_MessageValue {
// SAFETY:
// The arena memory is not freed because we prevent its destructor from
// executing with the call to `std::mem::forget(arena)`.
let arena = unsafe { Arena::from_raw(raw_arena) };
let copied = copy_bytes_in_arena(&arena, val.into());
// SAFETY:
// `val` is a valid `ProtoStr` and `copied` is an exact copy of `val`.
let proto_str = unsafe { ProtoStr::from_utf8_unchecked(copied) };
let msg_val = Self::to_message_value(proto_str);
std::mem::forget(arena);
msg_val
// SAFETY: `raw_arena` is valid as promised by the caller
unsafe {
<[u8] as UpbTypeConversions>::to_message_value_copy_if_required(
raw_arena,
val.as_bytes(),
)
}
}
unsafe fn from_message_value<'msg>(msg: upb_MessageValue) -> View<'msg, ProtoStr> {

Loading…
Cancel
Save