From df376c807c40b2281dccea754561c42cc4157a6c Mon Sep 17 00:00:00 2001 From: Kevin King Date: Wed, 17 Jan 2024 11:41:53 -0800 Subject: [PATCH] Implement ProxiedInRepeated for Messages PiperOrigin-RevId: 599241012 --- rust/upb.rs | 46 +++- src/google/protobuf/compiler/rust/message.cc | 230 ++++++++++++++++++- 2 files changed, 269 insertions(+), 7 deletions(-) diff --git a/rust/upb.rs b/rust/upb.rs index 50aecdae60..c1875180ea 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -342,6 +342,11 @@ extern "C" { mini_table: *const OpaqueMiniTable, arena: RawArena, ); + pub fn upb_Message_DeepClone( + m: RawMessage, + mini_table: *const OpaqueMiniTable, + arena: RawArena, + ) -> Option; } /// The raw type-erased pointer version of `RepeatedMut`. @@ -416,11 +421,11 @@ pub enum UpbCType { extern "C" { fn upb_Array_New(a: RawArena, r#type: std::ffi::c_int) -> RawRepeatedField; - fn upb_Array_Size(arr: RawRepeatedField) -> usize; - fn upb_Array_Set(arr: RawRepeatedField, i: usize, val: upb_MessageValue); - fn upb_Array_Get(arr: RawRepeatedField, i: usize) -> upb_MessageValue; - fn upb_Array_Append(arr: RawRepeatedField, val: upb_MessageValue, arena: RawArena); - fn upb_Array_Resize(arr: RawRepeatedField, size: usize, arena: RawArena) -> bool; + pub fn upb_Array_Size(arr: RawRepeatedField) -> usize; + pub fn upb_Array_Set(arr: RawRepeatedField, i: usize, val: upb_MessageValue); + pub fn upb_Array_Get(arr: RawRepeatedField, i: usize) -> upb_MessageValue; + pub fn upb_Array_Append(arr: RawRepeatedField, val: upb_MessageValue, arena: RawArena); + pub fn upb_Array_Resize(arr: RawRepeatedField, size: usize, arena: RawArena) -> bool; fn upb_Array_MutableDataPtr(arr: RawRepeatedField) -> *mut std::ffi::c_void; fn upb_Array_DataPtr(arr: RawRepeatedField) -> *const std::ffi::c_void; pub fn upb_Array_GetMutable(arr: RawRepeatedField, i: usize) -> upb_MutableMessageValue; @@ -511,6 +516,37 @@ impl_repeated_primitives!( (u64, uint64_val, UpbCType::UInt64), ); +/// Copy the contents of `src` into `dest`. +/// +/// # Safety +/// - `minitable` must be a pointer to the minitable for message `T`. +pub unsafe fn repeated_message_copy_from( + src: View>, + mut dest: Mut>, + minitable: *const OpaqueMiniTable, +) { + // SAFETY: + // - `src.as_raw()` is a valid `const upb_Array*`. + // - `dest.as_raw()` is a valid `upb_Array*`. + // - Elements of `src` and have message minitable `$minitable$`. + unsafe { + let size = upb_Array_Size(src.as_raw(Private)); + if !upb_Array_Resize(dest.as_raw(Private), size, dest.raw_arena(Private)) { + panic!("upb_Array_Resize failed."); + } + for i in 0..size { + let src_msg = upb_Array_Get(src.as_raw(Private), i) + .msg_val + .expect("upb_Array* element should not be NULL"); + // Avoid the use of `upb_Array_DeepClone` as it creates an + // entirely new `upb_Array*` at a new memory address. + let cloned_msg = upb_Message_DeepClone(src_msg, minitable, dest.raw_arena(Private)) + .expect("upb_Message_DeepClone failed."); + upb_Array_Set(dest.as_raw(Private), i, upb_MessageValue { msg_val: Some(cloned_msg) }); + } + } +} + /// Cast a `RepeatedView` to `RepeatedView`. pub fn cast_enum_repeated_view( private: Private, diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index d46e4fc4d9..647bb369b4 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -134,6 +134,14 @@ void MessageExterns(Context& ctx, const Descriptor& msg) { {"serialize_thunk", ThunkName(ctx, msg, "serialize")}, {"deserialize_thunk", ThunkName(ctx, msg, "deserialize")}, {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}, + {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, + {"repeated_get_thunk", ThunkName(ctx, msg, "repeated_get")}, + {"repeated_get_mut_thunk", + ThunkName(ctx, msg, "repeated_get_mut")}, + {"repeated_add_thunk", ThunkName(ctx, msg, "repeated_add")}, + {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, + {"repeated_copy_from_thunk", + ThunkName(ctx, msg, "repeated_copy_from")}, }, R"rs( fn $new_thunk$() -> $pbi$::RawMessage; @@ -141,6 +149,12 @@ void MessageExterns(Context& ctx, const Descriptor& msg) { fn $serialize_thunk$(raw_msg: $pbi$::RawMessage) -> $pbr$::SerializedData; fn $deserialize_thunk$(raw_msg: $pbi$::RawMessage, data: $pbr$::SerializedData) -> bool; fn $copy_from_thunk$(dst: $pbi$::RawMessage, src: $pbi$::RawMessage); + fn $repeated_len_thunk$(raw: $pbi$::RawRepeatedField) -> usize; + fn $repeated_add_thunk$(raw: $pbi$::RawRepeatedField) -> $pbi$::RawMessage; + fn $repeated_get_thunk$(raw: $pbi$::RawRepeatedField, index: usize) -> $pbi$::RawMessage; + fn $repeated_get_mut_thunk$(raw: $pbi$::RawRepeatedField, index: usize) -> $pbi$::RawMessage; + fn $repeated_clear_thunk$(raw: $pbi$::RawRepeatedField); + fn $repeated_copy_from_thunk$(dst: $pbi$::RawRepeatedField, src: $pbi$::RawRepeatedField); )rs"); return; @@ -193,6 +207,7 @@ void MessageSettableValue(Context& ctx, const Descriptor& msg) { return; case Kernel::kUpb: + // TODO: Add owned SettableValue impl for upb messages. ctx.Emit({{"minitable", UpbMinitableName(msg)}}, R"rs( impl<'msg> $pb$::SettableValue<$Msg$> for $Msg$View<'msg> { fn set_on<'dst>( @@ -213,6 +228,184 @@ void MessageSettableValue(Context& ctx, const Descriptor& msg) { ABSL_LOG(FATAL) << "unreachable"; } +void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) { + switch (ctx.opts().kernel) { + case Kernel::kCpp: + ctx.Emit( + { + {"Msg", RsSafeName(msg.name())}, + {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}, + {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, + {"repeated_get_thunk", ThunkName(ctx, msg, "repeated_get")}, + {"repeated_get_mut_thunk", + ThunkName(ctx, msg, "repeated_get_mut")}, + {"repeated_add_thunk", ThunkName(ctx, msg, "repeated_add")}, + {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, + {"repeated_copy_from_thunk", + ThunkName(ctx, msg, "repeated_copy_from")}, + }, + R"rs( + unsafe impl $pb$::ProxiedInRepeated for $Msg$ { + fn repeated_len(f: $pb$::View<$pb$::Repeated>) -> usize { + // SAFETY: `f.as_raw()` is a valid `RepeatedPtrField*`. + unsafe { $repeated_len_thunk$(f.as_raw($pbi$::Private)) } + } + + unsafe fn repeated_set_unchecked( + mut f: $pb$::Mut<$pb$::Repeated>, + i: usize, + v: $pb$::View, + ) { + // SAFETY: + // - `f.as_raw()` is a valid `RepeatedPtrField*`. + // - `i < len(f)` is promised by caller. + // - `v.raw_msg()` is a valid `const Message&`. + unsafe { + $copy_from_thunk$( + $repeated_get_mut_thunk$(f.as_raw($pbi$::Private), i), + v.raw_msg(), + ); + } + } + + unsafe fn repeated_get_unchecked( + f: $pb$::View<$pb$::Repeated>, + i: usize, + ) -> $pb$::View { + // SAFETY: + // - `f.as_raw()` is a valid `const RepeatedPtrField&`. + // - `i < len(f)` is promised by caller. + let msg = unsafe { $repeated_get_thunk$(f.as_raw($pbi$::Private), i) }; + $pb$::View::::new($pbi$::Private, msg) + } + fn repeated_clear(mut f: $pb$::Mut<$pb$::Repeated>) { + // SAFETY: + // - `f.as_raw()` is a valid `RepeatedPtrField*`. + unsafe { $repeated_clear_thunk$(f.as_raw($pbi$::Private)) }; + } + + fn repeated_push(mut f: $pb$::Mut<$pb$::Repeated>, v: $pb$::View) { + // SAFETY: + // - `f.as_raw()` is a valid `RepeatedPtrField*`. + // - `v.raw_msg()` is a valid `const Message&`. + unsafe { + let new_elem = $repeated_add_thunk$(f.as_raw($pbi$::Private)); + $copy_from_thunk$(new_elem, v.raw_msg()); + } + } + + fn repeated_copy_from( + src: $pb$::View<$pb$::Repeated>, + mut dest: $pb$::Mut<$pb$::Repeated>, + ) { + // SAFETY: + // - `dest.as_raw()` is a valid `RepeatedPtrField*`. + // - `src.as_raw()` is a valid `const RepeatedPtrField&`. + unsafe { + $repeated_copy_from_thunk$(dest.as_raw($pbi$::Private), src.as_raw($pbi$::Private)); + } + } + } + + )rs"); + return; + case Kernel::kUpb: + ctx.Emit( + { + {"minitable", UpbMinitableName(msg)}, + {"new_thunk", ThunkName(ctx, msg, "new")}, + }, + R"rs( + unsafe impl $pb$::ProxiedInRepeated for $Msg$ { + fn repeated_len(f: $pb$::View<$pb$::Repeated>) -> usize { + // SAFETY: `f.as_raw()` is a valid `upb_Array*`. + unsafe { $pbr$::upb_Array_Size(f.as_raw($pbi$::Private)) } + } + unsafe fn repeated_set_unchecked( + mut f: $pb$::Mut<$pb$::Repeated>, + i: usize, + v: $pb$::View, + ) { + // SAFETY: + // - `f.as_raw()` is a valid `upb_Array*`. + // - `i < len(f)` is promised by the caller. + let mut dest_msg = unsafe { + $pbr$::upb_Array_GetMutable(f.as_raw($pbi$::Private), i).msg + }.expect("upb_Array* element should not be NULL"); + + // SAFETY: + // - `dest_msg` is a valid `upb_Message*`. + // - `v.raw_msg()` and `dest_msg` both have message minitable `$minitable$`. + unsafe { + $pbr$::upb_Message_DeepCopy( + dest_msg, + v.raw_msg(), + $std$::ptr::addr_of!($minitable$), + f.raw_arena($pbi$::Private), + ) + }; + } + + unsafe fn repeated_get_unchecked( + f: $pb$::View<$pb$::Repeated>, + i: usize, + ) -> $pb$::View { + // SAFETY: + // - `f.as_raw()` is a valid `const upb_Array*`. + // - `i < len(f)` is promised by the caller. + let msg_ptr = unsafe { $pbr$::upb_Array_Get(f.as_raw($pbi$::Private), i).msg_val } + .expect("upb_Array* element should not be NULL."); + $pb$::View::::new($pbi$::Private, msg_ptr) + } + + fn repeated_clear(mut f: $pb$::Mut<$pb$::Repeated>) { + // SAFETY: + // - `f.as_raw()` is a valid `upb_Array*`. + unsafe { + $pbr$::upb_Array_Resize(f.as_raw($pbi$::Private), 0, f.raw_arena($pbi$::Private)) + }; + } + fn repeated_push(mut f: $pb$::Mut<$pb$::Repeated>, v: $pb$::View) { + // SAFETY: + // - `v.raw_msg()` is a valid `const upb_Message*` with minitable `$minitable$`. + let msg_ptr = unsafe { + $pbr$::upb_Message_DeepClone( + v.raw_msg(), + std::ptr::addr_of!($minitable$), + f.raw_arena($pbi$::Private), + ) + }.expect("upb_Message_DeepClone failed."); + + // Append new default message to array. + // SAFETY: + // - `f.as_raw()` is a valid `upb_Array*`. + // - `msg_ptr` is a valid `upb_Message*`. + unsafe { + $pbr$::upb_Array_Append( + f.as_raw($pbi$::Private), + $pbr$::upb_MessageValue{msg_val: Some(msg_ptr)}, + f.raw_arena($pbi$::Private), + ); + }; + } + + fn repeated_copy_from( + src: $pb$::View<$pb$::Repeated>, + mut dest: $pb$::Mut<$pb$::Repeated>, + ) { + // SAFETY: + // - Elements of `src` and `dest` have message minitable `$minitable$`. + unsafe { + $pbr$::repeated_message_copy_from(src, dest, $std$::ptr::addr_of!($minitable$)); + } + } + } + )rs"); + return; + } + ABSL_LOG(FATAL) << "unreachable"; +} + } // namespace void GenerateRs(Context& ctx, const Descriptor& msg) { @@ -326,7 +519,8 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { AccessorCase::MUT); } }}, - {"settable_impl", [&] { MessageSettableValue(ctx, msg); }}}, + {"settable_impl", [&] { MessageSettableValue(ctx, msg); }}, + {"repeated_impl", [&] { MessageProxiedInRepeated(ctx, msg); }}}, R"rs( #[allow(non_camel_case_types)] //~ TODO: Implement support for debug redaction @@ -389,6 +583,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { } $settable_impl$ + $repeated_impl$ #[derive(Debug)] #[allow(dead_code)] @@ -530,13 +725,19 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { ctx.Emit( {{"abi", "\"C\""}, // Workaround for syntax highlight bug in VSCode. - {"Msg", msg.name()}, + {"Msg", RsSafeName(msg.name())}, {"QualifiedMsg", cpp::QualifiedClassName(&msg)}, {"new_thunk", ThunkName(ctx, msg, "new")}, {"delete_thunk", ThunkName(ctx, msg, "delete")}, {"serialize_thunk", ThunkName(ctx, msg, "serialize")}, {"deserialize_thunk", ThunkName(ctx, msg, "deserialize")}, {"copy_from_thunk", ThunkName(ctx, msg, "copy_from")}, + {"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")}, + {"repeated_get_thunk", ThunkName(ctx, msg, "repeated_get")}, + {"repeated_get_mut_thunk", ThunkName(ctx, msg, "repeated_get_mut")}, + {"repeated_add_thunk", ThunkName(ctx, msg, "repeated_add")}, + {"repeated_clear_thunk", ThunkName(ctx, msg, "repeated_clear")}, + {"repeated_copy_from_thunk", ThunkName(ctx, msg, "repeated_copy_from")}, {"nested_msg_thunks", [&] { for (int i = 0; i < msg.nested_type_count(); ++i) { @@ -575,6 +776,31 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { dst->CopyFrom(*src); } + size_t $repeated_len_thunk$(google::protobuf::RepeatedPtrField<$QualifiedMsg$>* field) { + return field->size(); + } + const $QualifiedMsg$& $repeated_get_thunk$( + google::protobuf::RepeatedPtrField<$QualifiedMsg$>* field, + size_t index) { + return field->Get(index); + } + $QualifiedMsg$* $repeated_get_mut_thunk$( + google::protobuf::RepeatedPtrField<$QualifiedMsg$>* field, + size_t index) { + return field->Mutable(index); + } + $QualifiedMsg$* $repeated_add_thunk$(google::protobuf::RepeatedPtrField<$QualifiedMsg$>* field) { + return field->Add(); + } + void $repeated_clear_thunk$(google::protobuf::RepeatedPtrField<$QualifiedMsg$>* field) { + field->Clear(); + } + void $repeated_copy_from_thunk$( + google::protobuf::RepeatedPtrField<$QualifiedMsg$>& dst, + const google::protobuf::RepeatedPtrField<$QualifiedMsg$>& src) { + dst = src; + } + $accessor_thunks$ $oneof_thunks$