diff --git a/rust/cpp.rs b/rust/cpp.rs index 6ef124070e..7c4998f875 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -197,7 +197,7 @@ pub fn copy_bytes_in_arena_if_needed_by_runtime<'a>( /// must be different fields, and not be in the same oneof. As such, a `Mut` /// cannot be `Clone` but *can* reborrow itself with `.as_mut()`, which /// converts `&'b mut Mut<'a, T>` to `Mut<'b, T>`. -#[derive(Clone, Copy)] +#[derive(Debug)] pub struct RepeatedField<'msg, T: ?Sized> { inner: RepeatedFieldInner<'msg>, _phantom: PhantomData<&'msg mut T>, @@ -206,7 +206,7 @@ pub struct RepeatedField<'msg, T: ?Sized> { /// CPP runtime-specific arguments for initializing a RepeatedField. /// See RepeatedField comment about mutation invariants for when this type can /// be copied. -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub struct RepeatedFieldInner<'msg> { pub raw: RawRepeatedField, pub _phantom: PhantomData<&'msg ()>, @@ -217,7 +217,16 @@ impl<'msg, T: ?Sized> RepeatedField<'msg, T> { RepeatedField { inner, _phantom: PhantomData } } } -impl<'msg> RepeatedField<'msg, i32> {} + +// These use manual impls instead of derives to avoid unnecessary bounds on `T`. +// This problem is referred to as "perfect derive". +// https://smallcultfollowing.com/babysteps/blog/2022/04/12/implied-bounds-and-perfect-derive/ +impl<'msg, T: ?Sized> Copy for RepeatedField<'msg, T> {} +impl<'msg, T: ?Sized> Clone for RepeatedField<'msg, T> { + fn clone(&self) -> RepeatedField<'msg, T> { + *self + } +} pub trait RepeatedScalarOps { fn new_repeated_field() -> RawRepeatedField; @@ -225,6 +234,7 @@ pub trait RepeatedScalarOps { fn len(f: RawRepeatedField) -> usize; fn get(f: RawRepeatedField, i: usize) -> Self; fn set(f: RawRepeatedField, i: usize, v: Self); + fn copy_from(src: RawRepeatedField, dst: RawRepeatedField); } macro_rules! impl_repeated_scalar_ops { @@ -236,6 +246,7 @@ macro_rules! impl_repeated_scalar_ops { fn [< __pb_rust_RepeatedField_ $t _size >](f: RawRepeatedField) -> usize; fn [< __pb_rust_RepeatedField_ $t _get >](f: RawRepeatedField, i: usize) -> $t; fn [< __pb_rust_RepeatedField_ $t _set >](f: RawRepeatedField, i: usize, v: $t); + fn [< __pb_rust_RepeatedField_ $t _copy_from >](src: RawRepeatedField, dst: RawRepeatedField); } impl RepeatedScalarOps for $t { fn new_repeated_field() -> RawRepeatedField { @@ -253,6 +264,9 @@ macro_rules! impl_repeated_scalar_ops { fn set(f: RawRepeatedField, i: usize, v: Self) { unsafe { [< __pb_rust_RepeatedField_ $t _set >](f, i, v) } } + fn copy_from(src: RawRepeatedField, dst: RawRepeatedField) { + unsafe { [< __pb_rust_RepeatedField_ $t _copy_from >](src, dst) } + } } )* } }; @@ -292,6 +306,9 @@ impl<'msg, T: RepeatedScalarOps> RepeatedField<'msg, T> { } T::set(self.inner.raw, index, val) } + pub fn copy_from(&mut self, src: &RepeatedField<'_, T>) { + T::copy_from(src.inner.raw, self.inner.raw) + } } #[cfg(test)] diff --git a/rust/cpp_kernel/cpp_api.cc b/rust/cpp_kernel/cpp_api.cc index 8ff79d8fa9..1381611064 100644 --- a/rust/cpp_kernel/cpp_api.cc +++ b/rust/cpp_kernel/cpp_api.cc @@ -2,25 +2,29 @@ extern "C" { -#define expose_repeated_field_methods(ty, rust_ty) \ - google::protobuf::RepeatedField* __pb_rust_RepeatedField_##rust_ty##_new() { \ - return new google::protobuf::RepeatedField(); \ - } \ - void __pb_rust_RepeatedField_##rust_ty##_add(google::protobuf::RepeatedField* r, \ - ty val) { \ - r->Add(val); \ - } \ - size_t __pb_rust_RepeatedField_##rust_ty##_size( \ - google::protobuf::RepeatedField* r) { \ - return r->size(); \ - } \ - ty __pb_rust_RepeatedField_##rust_ty##_get(google::protobuf::RepeatedField* r, \ - size_t index) { \ - return r->Get(index); \ - } \ - void __pb_rust_RepeatedField_##rust_ty##_set(google::protobuf::RepeatedField* r, \ - size_t index, ty val) { \ - return r->Set(index, val); \ +#define expose_repeated_field_methods(ty, rust_ty) \ + google::protobuf::RepeatedField* __pb_rust_RepeatedField_##rust_ty##_new() { \ + return new google::protobuf::RepeatedField(); \ + } \ + void __pb_rust_RepeatedField_##rust_ty##_add(google::protobuf::RepeatedField* r, \ + ty val) { \ + r->Add(val); \ + } \ + size_t __pb_rust_RepeatedField_##rust_ty##_size( \ + google::protobuf::RepeatedField* r) { \ + return r->size(); \ + } \ + ty __pb_rust_RepeatedField_##rust_ty##_get(google::protobuf::RepeatedField* r, \ + size_t index) { \ + return r->Get(index); \ + } \ + void __pb_rust_RepeatedField_##rust_ty##_set(google::protobuf::RepeatedField* r, \ + size_t index, ty val) { \ + return r->Set(index, val); \ + } \ + void __pb_rust_RepeatedField_##rust_ty##_copy_from( \ + google::protobuf::RepeatedField const& src, google::protobuf::RepeatedField& dst) { \ + dst.CopyFrom(src); \ } expose_repeated_field_methods(int32_t, i32); diff --git a/rust/primitive.rs b/rust/primitive.rs index e69afe66f4..dd2fec18d4 100644 --- a/rust/primitive.rs +++ b/rust/primitive.rs @@ -7,6 +7,7 @@ use crate::__internal::Private; use crate::__runtime::InnerPrimitiveMut; +use crate::repeated::RepeatedMut; use crate::vtable::{ PrimitiveOptionalMutVTable, PrimitiveVTable, ProxiedWithRawOptionalVTable, ProxiedWithRawVTable, RawVTableOptionalMutatorData, @@ -14,18 +15,31 @@ use crate::vtable::{ use crate::{Mut, MutProxy, Proxied, ProxiedWithPresence, SettableValue, View, ViewProxy}; #[derive(Debug)] -pub struct PrimitiveMut<'a, T: ProxiedWithRawVTable> { +pub struct SingularPrimitiveMut<'a, T: ProxiedWithRawVTable> { inner: InnerPrimitiveMut<'a, T>, } +#[derive(Debug)] +pub enum PrimitiveMut<'a, T: ProxiedWithRawVTable> { + Singular(SingularPrimitiveMut<'a, T>), + Repeated(RepeatedMut<'a, T>, usize), +} + impl<'a, T: ProxiedWithRawVTable> PrimitiveMut<'a, T> { + #[doc(hidden)] + pub fn from_singular(_private: Private, inner: InnerPrimitiveMut<'a, T>) -> Self { + PrimitiveMut::Singular(SingularPrimitiveMut::from_inner(_private, inner)) + } +} + +impl<'a, T: ProxiedWithRawVTable> SingularPrimitiveMut<'a, T> { #[doc(hidden)] pub fn from_inner(_private: Private, inner: InnerPrimitiveMut<'a, T>) -> Self { Self { inner } } } -unsafe impl<'a, T: ProxiedWithRawVTable> Sync for PrimitiveMut<'a, T> {} +unsafe impl<'a, T: ProxiedWithRawVTable> Sync for SingularPrimitiveMut<'a, T> {} macro_rules! impl_singular_primitives { ($($t:ty),*) => { @@ -47,6 +61,29 @@ macro_rules! impl_singular_primitives { } } + impl<'a> PrimitiveMut<'a, $t> { + pub fn get(&self) -> View<'_, $t> { + match self { + PrimitiveMut::Singular(s) => { + s.get() + } + PrimitiveMut::Repeated(r, i) => { + r.get().get(*i).unwrap() + } + } + } + + pub fn set(&mut self, val: impl SettableValue<$t>) { + val.set_on(Private, self.as_mut()); + } + + pub fn clear(&mut self) { + // The default value for a boolean field is false and 0 for numerical types. It + // matches the Rust default values for corresponding types. Let's use this fact. + SettableValue::<$t>::set_on(<$t>::default(), Private, MutProxy::as_mut(self)); + } + } + impl<'a> ViewProxy<'a> for PrimitiveMut<'a, $t> { type Proxied = $t; @@ -61,7 +98,14 @@ macro_rules! impl_singular_primitives { impl<'a> MutProxy<'a> for PrimitiveMut<'a, $t> { fn as_mut(&mut self) -> Mut<'_, Self::Proxied> { - PrimitiveMut::from_inner(Private, self.inner) + match self { + PrimitiveMut::Singular(s) => { + PrimitiveMut::Singular(s.as_mut()) + } + PrimitiveMut::Repeated(r, i) => { + PrimitiveMut::Repeated(r.as_mut(), *i) + } + } } fn into_mut<'shorter>(self) -> Mut<'shorter, Self::Proxied> @@ -73,23 +117,23 @@ macro_rules! impl_singular_primitives { impl SettableValue<$t> for $t { fn set_on(self, _private: Private, mutator: Mut<'_, $t>) { - unsafe { (mutator.inner).set(self) }; + match mutator { + PrimitiveMut::Singular(s) => { + unsafe { (s.inner).set(self) }; + } + PrimitiveMut::Repeated(mut r, i) => { + r.set(i, self); + } + } } } - impl<'a> PrimitiveMut<'a, $t> { - pub fn set(&mut self, val: impl SettableValue<$t>) { - val.set_on(Private, self.as_mut()); - } - + impl<'a> SingularPrimitiveMut<'a, $t> { pub fn get(&self) -> $t { self.inner.get() } - - pub fn clear(&mut self) { - // The default value for a boolean field is false and 0 for numerical types. It - // matches the Rust default values for corresponding types. Let's use this fact. - SettableValue::<$t>::set_on(<$t>::default(), Private, MutProxy::as_mut(self)); + pub fn as_mut(&mut self) -> SingularPrimitiveMut<'_, $t> { + SingularPrimitiveMut::from_inner(Private, self.inner) } } @@ -104,7 +148,7 @@ macro_rules! impl_singular_primitives { } fn make_mut(_private: Private, inner: InnerPrimitiveMut<'_, Self>) -> Mut<'_, Self> { - PrimitiveMut::from_inner(Private, inner) + PrimitiveMut::Singular(SingularPrimitiveMut::from_inner(Private, inner)) } } @@ -121,7 +165,7 @@ macro_rules! impl_singular_primitives { fn set_absent_to_default( absent_mutator: Self::AbsentMutData<'_>, ) -> Self::PresentMutData<'_> { - absent_mutator.set_absent_to_default() + absent_mutator.set_absent_to_default() } } diff --git a/rust/repeated.rs b/rust/repeated.rs index b824de8399..1abb7152ba 100644 --- a/rust/repeated.rs +++ b/rust/repeated.rs @@ -12,8 +12,11 @@ use std::marker::PhantomData; use crate::{ + Mut, MutProxy, Proxied, SettableValue, View, ViewProxy, __internal::{Private, RawRepeatedField}, __runtime::{RepeatedField, RepeatedFieldInner}, + primitive::PrimitiveMut, + vtable::ProxiedWithRawVTable, }; #[derive(Clone, Copy)] @@ -31,6 +34,9 @@ pub struct RepeatedView<'a, T: ?Sized> { inner: RepeatedField<'a, T>, } +unsafe impl<'a, T: ProxiedWithRawVTable> Sync for RepeatedView<'a, T> {} +unsafe impl<'a, T: ProxiedWithRawVTable> Send for RepeatedView<'a, T> {} + impl<'msg, T: ?Sized> RepeatedView<'msg, T> { pub fn from_inner(_private: Private, inner: RepeatedFieldInner<'msg>) -> Self { Self { inner: RepeatedField::<'msg>::from_inner(_private, inner) } @@ -49,14 +55,20 @@ impl<'a, T> std::fmt::Debug for RepeatedView<'a, T> { } #[repr(transparent)] +#[derive(Debug)] pub struct RepeatedMut<'a, T: ?Sized> { inner: RepeatedField<'a, T>, } +unsafe impl<'a, T: ProxiedWithRawVTable> Sync for RepeatedMut<'a, T> {} + impl<'msg, T: ?Sized> RepeatedMut<'msg, T> { pub fn from_inner(_private: Private, inner: RepeatedFieldInner<'msg>) -> Self { Self { inner: RepeatedField::from_inner(_private, inner) } } + pub fn as_mut(&self) -> RepeatedMut<'_, T> { + Self { inner: self.inner } + } } impl<'a, T> std::ops::Deref for RepeatedMut<'a, T> { @@ -70,9 +82,67 @@ impl<'a, T> std::ops::Deref for RepeatedMut<'a, T> { } } +pub struct RepeatedFieldIterMut<'a, T> { + inner: RepeatedMut<'a, T>, + current_index: usize, +} + +pub struct Repeated(PhantomData); + macro_rules! impl_repeated_primitives { ($($t:ty),*) => { $( + impl Proxied for Repeated<$t> { + type View<'a> = RepeatedView<'a, $t>; + type Mut<'a> = RepeatedMut<'a, $t>; + } + + impl<'a> ViewProxy<'a> for RepeatedView<'a, $t> { + type Proxied = Repeated<$t>; + + fn as_view(&self) -> View<'_, Self::Proxied> { + *self + } + + fn into_view<'shorter>(self) -> View<'shorter, Self::Proxied> + where 'a: 'shorter, + { + RepeatedView { inner: self.inner } + } + } + + impl<'a> ViewProxy<'a> for RepeatedMut<'a, $t> { + type Proxied = Repeated<$t>; + + fn as_view(&self) -> View<'_, Self::Proxied> { + **self + } + + fn into_view<'shorter>(self) -> View<'shorter, Self::Proxied> + where 'a: 'shorter, + { + *self.into_mut::<'shorter>() + } + } + + impl<'a> MutProxy<'a> for RepeatedMut<'a, $t> { + fn as_mut(&mut self) -> Mut<'_, Self::Proxied> { + RepeatedMut { inner: self.inner } + } + + fn into_mut<'shorter>(self) -> Mut<'shorter, Self::Proxied> + where 'a: 'shorter, + { + RepeatedMut { inner: self.inner } + } + } + + impl <'a> SettableValue> for RepeatedView<'a, $t> { + fn set_on(self, _private: Private, mut mutator: Mut<'_, Repeated<$t>>) { + mutator.copy_from(self); + } + } + impl<'a> RepeatedView<'a, $t> { pub fn len(&self) -> usize { self.inner.len() @@ -83,6 +153,9 @@ macro_rules! impl_repeated_primitives { pub fn get(&self, index: usize) -> Option<$t> { self.inner.get(index) } + pub fn iter(&self) -> RepeatedFieldIter<'_, $t> { + (*self).into_iter() + } } impl<'a> RepeatedMut<'a, $t> { @@ -92,6 +165,21 @@ macro_rules! impl_repeated_primitives { pub fn set(&mut self, index: usize, val: $t) { self.inner.set(index, val) } + pub fn get_mut(&mut self, index: usize) -> Option> { + if index >= self.len() { + return None; + } + Some(PrimitiveMut::Repeated(self.as_mut(), index)) + } + pub fn iter(&self) -> RepeatedFieldIter<'_, $t> { + self.as_view().into_iter() + } + pub fn iter_mut(&mut self) -> RepeatedFieldIterMut<'_, $t> { + self.as_mut().into_iter() + } + pub fn copy_from(&mut self, src: RepeatedView<'_, $t>) { + self.inner.copy_from(&src.inner); + } } impl<'a> std::iter::Iterator for RepeatedFieldIter<'a, $t> { @@ -112,6 +200,32 @@ macro_rules! impl_repeated_primitives { RepeatedFieldIter { inner: self.inner, current_index: 0 } } } + + impl <'a> std::iter::Iterator for RepeatedFieldIterMut<'a, $t> { + type Item = Mut<'a, $t>; + fn next(&mut self) -> Option { + if self.current_index >= self.inner.len() { + return None; + } + let elem = PrimitiveMut::Repeated( + // While this appears to allow mutable aliasing + // (multiple `Self::Item`s can co-exist), each `Item` + // only references a specific unique index. + RepeatedMut{ inner: self.inner.inner }, + self.current_index, + ); + self.current_index += 1; + Some(elem) + } + } + + impl<'a> std::iter::IntoIterator for RepeatedMut<'a, $t> { + type Item = Mut<'a, $t>; + type IntoIter = RepeatedFieldIterMut<'a, $t>; + fn into_iter(self) -> Self::IntoIter { + RepeatedFieldIterMut { inner: self, current_index: 0 } + } + } )* } } diff --git a/rust/shared.rs b/rust/shared.rs index f8a9d117d9..7c3a3d163d 100644 --- a/rust/shared.rs +++ b/rust/shared.rs @@ -18,7 +18,7 @@ use std::fmt; #[doc(hidden)] pub mod __public { pub use crate::optional::{AbsentField, FieldEntry, Optional, PresentField}; - pub use crate::primitive::PrimitiveMut; + pub use crate::primitive::{PrimitiveMut, SingularPrimitiveMut}; pub use crate::proxied::{ Mut, MutProxy, Proxied, ProxiedWithPresence, SettableValue, View, ViewProxy, }; diff --git a/rust/test/shared/accessors_test.rs b/rust/test/shared/accessors_test.rs index ef33f21bf3..f1f55f663d 100644 --- a/rust/test/shared/accessors_test.rs +++ b/rust/test/shared/accessors_test.rs @@ -714,7 +714,47 @@ macro_rules! generate_repeated_numeric_test { assert_that!(mutator.get(0), some(eq(2 as $t))); mutator.push(1 as $t); - assert_that!(mutator.into_iter().collect::>(), eq(vec![2 as $t, 1 as $t])); + mutator.push(3 as $t); + assert_that!(mutator.get_mut(2).is_some(), eq(true)); + let mut mut_elem = mutator.get_mut(2).unwrap(); + mut_elem.set(4 as $t); + assert_that!(mut_elem.get(), eq(4 as $t)); + mut_elem.clear(); + assert_that!(mut_elem.get(), eq(0 as $t)); + + assert_that!( + mutator.iter().collect::>(), + eq(vec![2 as $t, 1 as $t, 0 as $t]) + ); + assert_that!( + (*mutator).into_iter().collect::>(), + eq(vec![2 as $t, 1 as $t, 0 as $t]) + ); + + for mut mutable_elem in msg.[]() { + mutable_elem.set(0 as $t); + } + assert_that!( + msg.[]().iter().all(|v| v == (0 as $t)), + eq(true) + ); + } + + #[test] + fn [< test_repeated_ $field _set >]() { + let mut msg = TestAllTypes::new(); + let mut mutator = msg.[](); + let mut msg2 = TestAllTypes::new(); + let mut mutator2 = msg2.[](); + for i in 0..5 { + mutator2.push(i as $t); + } + protobuf::MutProxy::set(&mut mutator, *mutator2); + + assert_that!( + mutator.iter().collect::>(), + eq(mutator2.iter().collect::>()) + ); } )* } }; @@ -742,5 +782,34 @@ fn test_repeated_bool_accessors() { mutator.set(0, false); assert_that!(mutator.get(0), some(eq(false))); mutator.push(true); - assert_that!(mutator.into_iter().collect::>(), eq(vec![false, true])); + + mutator.push(false); + assert_that!(mutator.get_mut(2), pat!(Some(_))); + let mut mut_elem = mutator.get_mut(2).unwrap(); + mut_elem.set(true); + assert_that!(mut_elem.get(), eq(true)); + mut_elem.clear(); + assert_that!(mut_elem.get(), eq(false)); + + assert_that!(mutator.iter().collect::>(), eq(vec![false, true, false])); + assert_that!((*mutator).into_iter().collect::>(), eq(vec![false, true, false])); + + for mut mutable_elem in msg.repeated_bool_mut() { + mutable_elem.set(false); + } + assert_that!(msg.repeated_bool().iter().all(|v| v), eq(false)); +} + +#[test] +fn test_repeated_bool_set() { + let mut msg = TestAllTypes::new(); + let mut mutator = msg.repeated_bool_mut(); + let mut msg2 = TestAllTypes::new(); + let mut mutator2 = msg2.repeated_bool_mut(); + for _ in 0..5 { + mutator2.push(true); + } + protobuf::MutProxy::set(&mut mutator, *mutator2); + + assert_that!(mutator.iter().collect::>(), eq(mutator2.iter().collect::>())); } diff --git a/rust/upb.rs b/rust/upb.rs index 24edbce347..b6fd7b322a 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -293,12 +293,22 @@ pub struct RepeatedFieldInner<'msg> { pub arena: &'msg Arena, } -#[derive(Clone, Copy, Debug)] +#[derive(Debug)] pub struct RepeatedField<'msg, T: ?Sized> { inner: RepeatedFieldInner<'msg>, _phantom: PhantomData<&'msg mut T>, } +// These use manual impls instead of derives to avoid unnecessary bounds on `T`. +// This problem is referred to as "perfect derive". +// https://smallcultfollowing.com/babysteps/blog/2022/04/12/implied-bounds-and-perfect-derive/ +impl<'msg, T: ?Sized> Copy for RepeatedField<'msg, T> {} +impl<'msg, T: ?Sized> Clone for RepeatedField<'msg, T> { + fn clone(&self) -> RepeatedField<'msg, T> { + *self + } +} + impl<'msg, T: ?Sized> RepeatedField<'msg, T> { pub fn len(&self) -> usize { unsafe { upb_Array_Size(self.inner.raw) } @@ -352,6 +362,7 @@ extern "C" { fn upb_Array_Set(arr: RawRepeatedField, i: usize, val: upb_MessageValue); fn upb_Array_Get(arr: RawRepeatedField, i: usize) -> upb_MessageValue; fn upb_Array_Append(arr: RawRepeatedField, val: upb_MessageValue, arena: RawArena); + fn upb_Array_Resize(arr: RawRepeatedField, size: usize, arena: RawArena); } macro_rules! impl_repeated_primitives { @@ -392,6 +403,19 @@ macro_rules! impl_repeated_primitives { upb_MessageValue { $union_field: val }, ) } } + pub fn copy_from(&mut self, src: &RepeatedField<'_, $rs_type>) { + // TODO: Optimize this copy_from implementation using memcopy. + // NOTE: `src` cannot be `self` because this would violate borrowing rules. + unsafe { upb_Array_Resize(self.inner.raw, 0, self.inner.arena.raw()) }; + // `upb_Array_DeepClone` is not used here because it returns + // a new `upb_Array*`. The contained `RawRepeatedField` must + // then be set to this new pointer, but other copies of this + // pointer may exist because of re-borrowed `RepeatedMut`s. + // Alternatively, a `clone_into` method could be exposed by upb. + for i in 0..src.len() { + self.push(src.get(i).unwrap()); + } + } } )* } diff --git a/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc b/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc index 818dc7ce9b..b7b7a55643 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc @@ -162,7 +162,7 @@ void SingularScalar::InMsgImpl(Context field) const { )rs"); } else { field.Emit({}, R"rs( - pub fn r#$field$_mut(&mut self) -> $pb$::PrimitiveMut<'_, $Scalar$> { + pub fn r#$field$_mut(&mut self) -> $pb$::Mut<'_, $Scalar$> { static VTABLE: $pbi$::PrimitiveVTable<$Scalar$> = $pbi$::PrimitiveVTable::new( $pbi$::Private, @@ -170,7 +170,7 @@ void SingularScalar::InMsgImpl(Context field) const { $setter_thunk$, ); - $pb$::PrimitiveMut::from_inner( + $pb$::PrimitiveMut::from_singular( $pbi$::Private, unsafe { $pbi$::RawVTableMutator::new(