From 9c994245f2eff656f4bede90d5d2fd2ffb01a2fd Mon Sep 17 00:00:00 2001 From: Derek Benson Date: Mon, 15 Jul 2024 12:36:48 -0700 Subject: [PATCH] implement repeated_new and repeated_free for enums and messages. Drop the default impl since it is now required PiperOrigin-RevId: 652567229 --- rust/cpp.rs | 33 ++++++++++++- rust/repeated.rs | 10 ++-- rust/test/shared/accessors_repeated_test.rs | 21 +++++++++ rust/upb.rs | 32 +++++++++++-- src/google/protobuf/compiler/rust/enum.cc | 8 ++++ src/google/protobuf/compiler/rust/message.cc | 49 ++++++++++++++++++++ 6 files changed, 141 insertions(+), 12 deletions(-) diff --git a/rust/cpp.rs b/rust/cpp.rs index c5257ba19c..d8e4df041a 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -371,6 +371,13 @@ impl InnerRepeated { pub fn raw(&self) -> RawRepeatedField { self.raw } + + /// # Safety + /// - `raw` must be a valid `proto2::RepeatedField*` or + /// `proto2::RepeatedPtrField*`. + pub unsafe fn from_raw(_: Private, raw: RawRepeatedField) -> Self { + Self { raw } + } } /// The raw type-erased pointer version of `RepeatedMut`. @@ -480,7 +487,7 @@ macro_rules! impl_repeated_primitives { #[allow(dead_code)] #[inline] fn repeated_new(_: Private) -> Repeated<$t> { - Repeated::from_inner(InnerRepeated { + Repeated::from_inner(Private, InnerRepeated { raw: unsafe { $new_thunk() } }) } @@ -581,6 +588,30 @@ pub fn reserve_enum_repeated_mut( ProxiedInRepeated::repeated_reserve(int_repeated, additional); } +pub fn new_enum_repeated(_: Private) -> Repeated { + let int_repeated = Repeated::::new(); + let raw = int_repeated.inner.raw(); + std::mem::forget(int_repeated); + unsafe { Repeated::from_inner(Private, InnerRepeated::from_raw(Private, raw)) } +} + +/// Cast a `RepeatedMut` to `RepeatedMut` and call +/// repeated_free. +/// # Safety +/// - The passed in `&mut Repeated` must not be used after this function is +/// called. +pub unsafe fn free_enum_repeated( + _: Private, + repeated: &mut Repeated, +) { + unsafe { + let mut int_r: Repeated = + Repeated::from_inner(Private, InnerRepeated::from_raw(Private, repeated.inner.raw())); + ProxiedInRepeated::repeated_free(Private, &mut int_r); + std::mem::forget(int_r); + } +} + #[derive(Debug)] pub struct InnerMap { pub(crate) raw: RawMap, diff --git a/rust/repeated.rs b/rust/repeated.rs index 4044b8c67e..99c935b4f4 100644 --- a/rust/repeated.rs +++ b/rust/repeated.rs @@ -287,18 +287,14 @@ where pub unsafe trait ProxiedInRepeated: Proxied { /// Constructs a new owned `Repeated` field. #[doc(hidden)] - fn repeated_new(_private: Private) -> Repeated { - unimplemented!("not required") - } + fn repeated_new(_private: Private) -> Repeated; /// Frees the repeated field in-place, for use in `Drop`. /// /// # Safety /// - After `repeated_free`, no other methods on the input are safe to call. #[doc(hidden)] - unsafe fn repeated_free(_private: Private, _repeated: &mut Repeated) { - unimplemented!("not required") - } + unsafe fn repeated_free(_private: Private, _repeated: &mut Repeated); /// Gets the length of the repeated field. fn repeated_len(repeated: View>) -> usize; @@ -365,7 +361,7 @@ impl Repeated { T::repeated_new(Private) } - pub(crate) fn from_inner(inner: InnerRepeated) -> Self { + pub fn from_inner(_private: Private, inner: InnerRepeated) -> Self { Self { inner, _phantom: PhantomData } } diff --git a/rust/test/shared/accessors_repeated_test.rs b/rust/test/shared/accessors_repeated_test.rs index cf3b9922ae..9f70074bf4 100644 --- a/rust/test/shared/accessors_repeated_test.rs +++ b/rust/test/shared/accessors_repeated_test.rs @@ -172,6 +172,18 @@ fn test_repeated_enum_accessors() { assert_that!(msg.repeated_nested_enum(), each(eq(NestedEnum::Foo))); } +#[test] +fn test_repeated_enum_set() { + use test_all_types::NestedEnum; + + let mut msg = TestAllTypes::new(); + msg.set_repeated_nested_enum([NestedEnum::Foo, NestedEnum::Bar, NestedEnum::Baz].into_iter()); + assert_that!( + msg.repeated_nested_enum(), + elements_are![eq(NestedEnum::Foo), eq(NestedEnum::Bar), eq(NestedEnum::Baz)] + ); +} + #[test] fn test_repeated_bool_set() { let mut msg = TestAllTypes::new(); @@ -218,6 +230,15 @@ fn test_repeated_message() { assert_that!(msg2.repeated_nested_message().len(), eq(0)); } +#[test] +fn test_repeated_message_setter() { + let mut msg = TestAllTypes::new(); + let mut nested = NestedMessage::new(); + nested.set_bb(1); + msg.set_repeated_nested_message([nested].into_iter()); + assert_that!(msg.repeated_nested_message().get(0).unwrap().bb(), eq(1)); +} + #[test] fn test_repeated_strings() { let mut older_msg = TestAllTypes::new(); diff --git a/rust/upb.rs b/rust/upb.rs index 5a8cdfac87..23f7563e78 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -189,6 +189,12 @@ impl InnerRepeated { pub fn arena(&self) -> &Arena { &self.arena } + + /// # Safety + /// - `raw` must be a valid `RawRepeatedField` + pub unsafe fn from_raw_parts(_: Private, raw: RawRepeatedField, arena: Arena) -> Self { + Self { raw, arena } + } } /// The raw type-erased pointer version of `RepeatedMut`. @@ -211,10 +217,10 @@ macro_rules! impl_repeated_base { #[inline] fn repeated_new(_: Private) -> Repeated<$t> { let arena = Arena::new(); - Repeated::from_inner(InnerRepeated { - raw: unsafe { upb_Array_New(arena.raw(), $upb_tag) }, - arena, - }) + Repeated::from_inner( + Private, + InnerRepeated { raw: unsafe { upb_Array_New(arena.raw(), $upb_tag) }, arena }, + ) } #[allow(dead_code)] unsafe fn repeated_free(_: Private, _f: &mut Repeated<$t>) { @@ -438,6 +444,24 @@ pub fn reserve_enum_repeated_mut( ProxiedInRepeated::repeated_reserve(int_repeated, additional); } +pub fn new_enum_repeated(_: Private) -> Repeated { + let arena = Arena::new(); + // SAFETY: + // - `upb_Array_New` is unsafe but assumed to be sound when called on a valid + // arena. + unsafe { + let raw = upb_Array_New(arena.raw(), upb::CType::Int32); + Repeated::from_inner(Private, InnerRepeated::from_raw_parts(Private, raw, arena)) + } +} + +pub fn free_enum_repeated( + _private: Private, + _repeated: &mut Repeated, +) { + // No-op: the memory will be dropped by the arena. +} + /// Returns a static empty RepeatedView. pub fn empty_array() -> RepeatedView<'static, T> { // TODO: Consider creating a static empty array in C. diff --git a/src/google/protobuf/compiler/rust/enum.cc b/src/google/protobuf/compiler/rust/enum.cc index 6c2117928d..74613aa8f2 100644 --- a/src/google/protobuf/compiler/rust/enum.cc +++ b/src/google/protobuf/compiler/rust/enum.cc @@ -418,6 +418,14 @@ void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) { } unsafe impl $pb$::ProxiedInRepeated for $name$ { + fn repeated_new(_private: $pbi$::Private) -> $pb$::Repeated { + $pbr$::new_enum_repeated($pbi$::Private) + } + + unsafe fn repeated_free(_private: $pbi$::Private, _f: &mut $pb$::Repeated) { + $pbr$::free_enum_repeated($pbi$::Private, _f) + } + fn repeated_len(r: $pb$::View<$pb$::Repeated>) -> usize { $pbr$::cast_enum_repeated_view($pbi$::Private, r).len() } diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 23875cbe9f..44f6ca8be3 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -184,6 +184,8 @@ void MessageExterns(Context& ctx, const Descriptor& msg) { {"parse_thunk", ThunkName(ctx, msg, "parse")}, {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}, {"merge_from_thunk", ThunkName(ctx, msg, "merge_from")}, + {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")}, + {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")}, {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, {"repeated_get_thunk", ThunkName(ctx, msg, "repeated_get")}, {"repeated_get_mut_thunk", @@ -202,6 +204,8 @@ void MessageExterns(Context& ctx, const Descriptor& msg) { fn $parse_thunk$(raw_msg: $pbr$::RawMessage, data: $pbr$::SerializedData) -> bool; fn $copy_from_thunk$(dst: $pbr$::RawMessage, src: $pbr$::RawMessage); fn $merge_from_thunk$(dst: $pbr$::RawMessage, src: $pbr$::RawMessage); + fn $repeated_new_thunk$() -> $pbr$::RawRepeatedField; + fn $repeated_free_thunk$(raw: $pbr$::RawRepeatedField); fn $repeated_len_thunk$(raw: $pbr$::RawRepeatedField) -> usize; fn $repeated_add_thunk$(raw: $pbr$::RawRepeatedField) -> $pbr$::RawMessage; fn $repeated_get_thunk$(raw: $pbr$::RawRepeatedField, index: usize) -> $pbr$::RawMessage; @@ -363,6 +367,8 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) { {"repeated_get_mut_thunk", ThunkName(ctx, msg, "repeated_get_mut")}, {"repeated_add_thunk", ThunkName(ctx, msg, "repeated_add")}, + {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")}, + {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")}, {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, {"repeated_copy_from_thunk", ThunkName(ctx, msg, "repeated_copy_from")}, @@ -371,6 +377,24 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) { }, R"rs( unsafe impl $pb$::ProxiedInRepeated for $Msg$ { + fn repeated_new(_private: $pbi$::Private) -> $pb$::Repeated { + // SAFETY: + // - The thunk returns an unaliased and valid `RepeatedPtrField*` + unsafe { + $pb$::Repeated::from_inner($pbi$::Private, + $pbr$::InnerRepeated::from_raw($pbi$::Private, + $repeated_new_thunk$() + ) + ) + } + } + + unsafe fn repeated_free(_private: $pbi$::Private, f: &mut $pb$::Repeated) { + // SAFETY + // - `f.raw()` is a valid `RepeatedPtrField*`. + unsafe { $repeated_free_thunk$(f.as_view().as_raw($pbi$::Private)) } + } + fn repeated_len(f: $pb$::View<$pb$::Repeated>) -> usize { // SAFETY: `f.as_raw()` is a valid `RepeatedPtrField*`. unsafe { $repeated_len_thunk$(f.as_raw($pbi$::Private)) } @@ -450,6 +474,21 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) { }, R"rs( unsafe impl $pb$::ProxiedInRepeated for $Msg$ { + fn repeated_new(_private: $pbi$::Private) -> $pb$::Repeated { + let arena = $pbr$::Arena::new(); + unsafe { + $pb$::Repeated::from_inner($pbi$::Private, $pbr$::InnerRepeated::from_raw_parts( + $pbi$::Private, + $pbr$::upb_Array_New(arena.raw(), $pbr$::CType::Message), + arena, + )) + } + } + + unsafe fn repeated_free(_private: $pbi$::Private, _f: &mut $pb$::Repeated) { + // No-op: the memory will be dropped by the arena. + } + fn repeated_len(f: $pb$::View<$pb$::Repeated>) -> usize { // SAFETY: `f.as_raw()` is a valid `upb_Array*`. unsafe { $pbr$::upb_Array_Size(f.as_raw($pbi$::Private)) } @@ -1215,6 +1254,8 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { {"parse_thunk", ThunkName(ctx, msg, "parse")}, {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}, {"merge_from_thunk", ThunkName(ctx, msg, "merge_from")}, + {"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")}, + {"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")}, {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, {"repeated_get_thunk", ThunkName(ctx, msg, "repeated_get")}, {"repeated_get_mut_thunk", ThunkName(ctx, msg, "repeated_get_mut")}, @@ -1267,6 +1308,14 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { dst->MergeFrom(*src); } + void* $repeated_new_thunk$() { + return new google::protobuf::RepeatedPtrField<$QualifiedMsg$>(); + } + + void $repeated_free_thunk$(void* ptr) { + delete static_cast*>(ptr); + } + size_t $repeated_len_thunk$(google::protobuf::RepeatedPtrField<$QualifiedMsg$>* field) { return field->size(); }