From cbb3edd86d38866aca518a0b5d699bf3789abad7 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Fri, 18 Oct 2024 10:48:42 -0700 Subject: [PATCH] Rust C++: get all map fields onto a common implementation of ProxiedInMapValue This CL migrates messages, enums, and primitive types all onto the same blanket implementation of the `ProxiedInMapValue` trait. This gets us to the point where messages and enums no longer need to generate any significant amount of extra code just in case they might be used as a map value. There are a few big pieces to this: - I generalized the message-specific FFI endpoints in `rust/cpp_kernel/map.cc` to be able to additionally handle enums and primitive types as values. This mostly consisted of replacing `MessageLite*` parameters with a new `MapValue` tagged union. - On the Rust side, I added a new blanket implementation of `ProxiedInMapValue` in rust/cpp.rs. It relies on its value type to implement a new `CppMapTypeConversions` trait so that it can convert to and from the `MapValue` tagged union used for FFI. - In the Rust generated code, I deleted the generated `ProxiedInMapValue` implementations for messages and enums and replaced them with implementations of the `CppMapTypeConversions` trait. PiperOrigin-RevId: 687355817 --- rust/cpp.rs | 548 +++++++++++++----- rust/cpp_kernel/BUILD | 1 - rust/cpp_kernel/map.cc | 319 +++++----- rust/cpp_kernel/map.h | 122 ---- src/google/protobuf/compiler/rust/enum.cc | 116 +--- .../protobuf/compiler/rust/generator.cc | 1 - src/google/protobuf/compiler/rust/message.cc | 133 +---- src/google/protobuf/compiler/rust/naming.cc | 38 -- src/google/protobuf/map.h | 1 + 9 files changed, 617 insertions(+), 662 deletions(-) delete mode 100644 rust/cpp_kernel/map.h diff --git a/rust/cpp.rs b/rust/cpp.rs index 460910463c..ed4bfc45a1 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -762,48 +762,397 @@ impl UntypedMapIterator { } } -// This enum is used to pass some information about the key type of a map to -// C++. The main purpose is to indicate the size of the key so that we can -// determine the correct size and offset information of map entries on the C++ -// side. We also rely on it to indicate whether the key is a string or not. +// LINT.IfChange(map_ffi) #[doc(hidden)] #[repr(u8)] -pub enum MapKeyCategory { - OneByte, - FourBytes, - EightBytes, - StdString, +#[derive(Debug, PartialEq)] +pub enum MapValueTag { + Bool, + U32, + U64, + String, + Message, } +// For the purposes of FFI, we treat all numeric types of a given size the same +// way. For example, u32, i32, and f32 values are all represented as a u32. +// Likewise, u64, i64, and f64 values are all stored in a u64. #[doc(hidden)] -pub trait MapKey { - const CATEGORY: MapKeyCategory; +#[repr(C)] +pub union MapValueUnion { + pub b: bool, + pub u: u32, + pub uu: u64, + // Generally speaking, if s is set then it should not be None. However, we + // do set it to None in the special case where the MapValue is just a + // "prototype" (see below). In that scenario, we just want to indicate the + // value type without having to allocate a real C++ std::string. + pub s: Option, + pub m: RawMessage, +} + +// We use this tagged union to represent map values for the purposes of FFI. +#[doc(hidden)] +#[repr(C)] +pub struct MapValue { + pub tag: MapValueTag, + pub val: MapValueUnion, +} +// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp_kernel/map.cc: +// map_ffi) + +impl MapValue { + fn make_bool(b: bool) -> Self { + MapValue { tag: MapValueTag::Bool, val: MapValueUnion { b } } + } + + pub fn make_u32(u: u32) -> Self { + MapValue { tag: MapValueTag::U32, val: MapValueUnion { u } } + } + + fn make_u64(uu: u64) -> Self { + MapValue { tag: MapValueTag::U64, val: MapValueUnion { uu } } + } + + fn make_string(s: CppStdString) -> Self { + MapValue { tag: MapValueTag::String, val: MapValueUnion { s: Some(s) } } + } + + pub fn make_message(m: RawMessage) -> Self { + MapValue { tag: MapValueTag::Message, val: MapValueUnion { m } } + } +} + +pub trait CppMapTypeConversions: Proxied { + // We have a notion of a map value "prototype", which is a MapValue that + // contains just enough information to indicate the value type of the map. + // We need this on the C++ side to be able to determine size and offset + // information about the map entry. For messages, the prototype is + // the message default instance. For all other types, it is just a MapValue + // with the appropriate tag. + fn get_prototype() -> MapValue; + + fn to_map_value(self) -> MapValue; + + /// # Safety + /// - `value` must store the correct type for `Self`. If it is a string or + /// bytes, then it must not be None. If `Self` is a closed enum, then + /// `value` must store a valid value for that enum. If `Self` is a + /// message, then `value` must store a message of the same type. + /// - The value must be valid for `'a` lifetime. + unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self>; +} + +impl CppMapTypeConversions for u32 { + fn get_prototype() -> MapValue { + MapValue::make_u32(0) + } + fn to_map_value(self) -> MapValue { + MapValue::make_u32(self) + } + unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> { + debug_assert_eq!(value.tag, MapValueTag::U32); + unsafe { value.val.u } + } +} + +impl CppMapTypeConversions for i32 { + fn get_prototype() -> MapValue { + MapValue::make_u32(0) + } + fn to_map_value(self) -> MapValue { + MapValue::make_u32(self as u32) + } + unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> { + debug_assert_eq!(value.tag, MapValueTag::U32); + unsafe { value.val.u as i32 } + } +} + +impl CppMapTypeConversions for u64 { + fn get_prototype() -> MapValue { + MapValue::make_u64(0) + } + fn to_map_value(self) -> MapValue { + MapValue::make_u64(self) + } + unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> { + debug_assert_eq!(value.tag, MapValueTag::U64); + unsafe { value.val.uu } + } +} + +impl CppMapTypeConversions for i64 { + fn get_prototype() -> MapValue { + MapValue::make_u64(0) + } + fn to_map_value(self) -> MapValue { + MapValue::make_u64(self as u64) + } + unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> { + debug_assert_eq!(value.tag, MapValueTag::U64); + unsafe { value.val.uu as i64 } + } +} + +impl CppMapTypeConversions for f32 { + fn get_prototype() -> MapValue { + MapValue::make_u32(0) + } + fn to_map_value(self) -> MapValue { + MapValue::make_u32(self.to_bits()) + } + unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> { + debug_assert_eq!(value.tag, MapValueTag::U32); + unsafe { Self::from_bits(value.val.u) } + } +} + +impl CppMapTypeConversions for f64 { + fn get_prototype() -> MapValue { + MapValue::make_u64(0) + } + fn to_map_value(self) -> MapValue { + MapValue::make_u64(self.to_bits()) + } + unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> { + debug_assert_eq!(value.tag, MapValueTag::U64); + unsafe { Self::from_bits(value.val.uu) } + } +} + +impl CppMapTypeConversions for bool { + fn get_prototype() -> MapValue { + MapValue::make_bool(false) + } + fn to_map_value(self) -> MapValue { + MapValue::make_bool(self) + } + unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> { + debug_assert_eq!(value.tag, MapValueTag::Bool); + unsafe { value.val.b } + } +} + +impl CppMapTypeConversions for ProtoString { + fn get_prototype() -> MapValue { + MapValue { tag: MapValueTag::String, val: MapValueUnion { s: None } } + } + + fn to_map_value(self) -> MapValue { + MapValue::make_string(protostr_into_cppstdstring(self)) + } + + unsafe fn from_map_value<'a>(value: MapValue) -> &'a ProtoStr { + debug_assert_eq!(value.tag, MapValueTag::String); + unsafe { + ProtoStr::from_utf8_unchecked( + ptrlen_to_str(proto2_rust_cpp_string_to_view(value.val.s.unwrap())).into(), + ) + } + } +} + +impl CppMapTypeConversions for ProtoBytes { + fn get_prototype() -> MapValue { + MapValue { tag: MapValueTag::String, val: MapValueUnion { s: None } } + } + + fn to_map_value(self) -> MapValue { + MapValue::make_string(protobytes_into_cppstdstring(self)) + } + + unsafe fn from_map_value<'a>(value: MapValue) -> &'a [u8] { + debug_assert_eq!(value.tag, MapValueTag::String); + unsafe { proto2_rust_cpp_string_to_view(value.val.s.unwrap()).as_ref() } + } +} + +// This trait encapsulates functionality that is specific to each map key type. +// We need this primarily so that we can call the appropriate FFI function for +// the key type. +#[doc(hidden)] +pub trait MapKey +where + Self: Proxied, +{ + type FfiKey; + + fn to_view<'a>(key: Self::FfiKey) -> View<'a, Self>; + + unsafe fn free(m: RawMap, prototype: MapValue); + + unsafe fn clear(m: RawMap, prototype: MapValue); + + unsafe fn insert(m: RawMap, key: View<'_, Self>, value: MapValue) -> bool; + + unsafe fn get( + m: RawMap, + prototype: MapValue, + key: View<'_, Self>, + value: *mut MapValue, + ) -> bool; + + unsafe fn iter_get( + iter: &mut UntypedMapIterator, + prototype: MapValue, + key: *mut Self::FfiKey, + value: *mut MapValue, + ); + + unsafe fn remove(m: RawMap, prototype: MapValue, key: View<'_, Self>) -> bool; } macro_rules! generate_map_key_impl { - ( $($key:ty, $category:expr;)* ) => { + ( $($key:ty, $mutable_ffi_key:ty, $to_ffi:expr, $from_ffi:expr;)* ) => { + paste! { $( impl MapKey for $key { - const CATEGORY: MapKeyCategory = $category; + type FfiKey = $mutable_ffi_key; + + #[inline] + fn to_view<'a>(key: Self::FfiKey) -> View<'a, Self> { + $from_ffi(key) + } + + #[inline] + unsafe fn free(m: RawMap, prototype: MapValue) { + unsafe { [< proto2_rust_map_free_ $key >](m, prototype) } + } + + #[inline] + unsafe fn clear(m: RawMap, prototype: MapValue) { + unsafe { [< proto2_rust_map_clear_ $key >](m, prototype) } + } + + #[inline] + unsafe fn insert( + m: RawMap, + key: View<'_, Self>, + value: MapValue, + ) -> bool { + unsafe { [< proto2_rust_map_insert_ $key >](m, $to_ffi(key), value) } + } + + #[inline] + unsafe fn get( + m: RawMap, + prototype: MapValue, + key: View<'_, Self>, + value: *mut MapValue, + ) -> bool { + unsafe { [< proto2_rust_map_get_ $key >](m, prototype, $to_ffi(key), value) } + } + + #[inline] + unsafe fn iter_get( + iter: &mut UntypedMapIterator, + prototype: MapValue, + key: *mut Self::FfiKey, + value: *mut MapValue, + ) { + unsafe { [< proto2_rust_map_iter_get_ $key >](iter, prototype, key, value) } + } + + #[inline] + unsafe fn remove(m: RawMap, prototype: MapValue, key: View<'_, Self>) -> bool { + unsafe { [< proto2_rust_map_remove_ $key >](m, prototype, $to_ffi(key)) } + } } )* + } } } -// LINT.IfChange(map_key_category) generate_map_key_impl!( - bool, MapKeyCategory::OneByte; - i32, MapKeyCategory::FourBytes; - u32, MapKeyCategory::FourBytes; - i64, MapKeyCategory::EightBytes; - u64, MapKeyCategory::EightBytes; - ProtoString, MapKeyCategory::StdString; + bool, bool, identity, identity; + i32, i32, identity, identity; + u32, u32, identity, identity; + i64, i64, identity, identity; + u64, u64, identity, identity; + ProtoString, PtrAndLen, str_to_ptrlen, ptrlen_to_str; ); -// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp_kernel/map.cc: -// map_key_category) + +impl ProxiedInMapValue for Value +where + Key: Proxied + MapKey, + Value: Proxied + CppMapTypeConversions, +{ + fn map_new(_private: Private) -> Map { + unsafe { Map::from_inner(Private, InnerMap::new(proto2_rust_map_new())) } + } + + unsafe fn map_free(_private: Private, map: &mut Map) { + unsafe { + Key::free(map.as_raw(Private), Self::get_prototype()); + } + } + + fn map_clear(mut map: MapMut) { + unsafe { + Key::clear(map.as_raw(Private), Self::get_prototype()); + } + } + + fn map_len(map: MapView) -> usize { + unsafe { proto2_rust_map_size(map.as_raw(Private)) } + } + + fn map_insert( + mut map: MapMut, + key: View<'_, Key>, + value: impl IntoProxied, + ) -> bool { + unsafe { Key::insert(map.as_raw(Private), key, value.into_proxied(Private).to_map_value()) } + } + + fn map_get<'a>(map: MapView<'a, Key, Self>, key: View<'_, Key>) -> Option> { + let mut value = std::mem::MaybeUninit::uninit(); + let found = unsafe { + Key::get(map.as_raw(Private), Self::get_prototype(), key, value.as_mut_ptr()) + }; + if !found { + return None; + } + unsafe { Some(Self::from_map_value(value.assume_init())) } + } + + fn map_remove(mut map: MapMut, key: View<'_, Key>) -> bool { + unsafe { Key::remove(map.as_raw(Private), Self::get_prototype(), key) } + } + + fn map_iter(map: MapView) -> MapIter { + // SAFETY: + // - The backing map for `map.as_raw` is valid for at least '_. + // - A View that is live for '_ guarantees the backing map is unmodified for '_. + // - The `iter` function produces an iterator that is valid for the key and + // value types, and live for at least '_. + unsafe { MapIter::from_raw(Private, proto2_rust_map_iter(map.as_raw(Private))) } + } + + fn map_iter_next<'a>( + iter: &mut MapIter<'a, Key, Self>, + ) -> Option<(View<'a, Key>, View<'a, Self>)> { + // SAFETY: + // - The `MapIter` API forbids the backing map from being mutated for 'a, and + // guarantees that it's the correct key and value types. + // - The thunk is safe to call as long as the iterator isn't at the end. + // - The thunk always writes to key and value fields and does not read. + // - The thunk does not increment the iterator. + unsafe { + iter.as_raw_mut(Private).next_unchecked::( + |iter, key, value| Key::iter_get(iter, Self::get_prototype(), key, value), + |ffi_key| Key::to_view(ffi_key), + |value| Self::from_map_value(value), + ) + } + } +} macro_rules! impl_map_primitives { (@impl $(($rust_type:ty, $cpp_type:ty) => [ + $free_thunk:ident, + $clear_thunk:ident, $insert_thunk:ident, $get_thunk:ident, $iter_get_thunk:ident, @@ -811,24 +1160,32 @@ macro_rules! impl_map_primitives { ]),* $(,)?) => { $( extern "C" { + pub fn $free_thunk( + m: RawMap, + prototype: MapValue, + ); + pub fn $clear_thunk( + m: RawMap, + prototype: MapValue, + ); pub fn $insert_thunk( m: RawMap, key: $cpp_type, - value: RawMessage, + value: MapValue, ) -> bool; pub fn $get_thunk( m: RawMap, - prototype: RawMessage, + prototype: MapValue, key: $cpp_type, - value: *mut RawMessage, + value: *mut MapValue, ) -> bool; pub fn $iter_get_thunk( iter: &mut UntypedMapIterator, - prototype: RawMessage, + prototype: MapValue, key: *mut $cpp_type, - value: *mut RawMessage, + value: *mut MapValue, ); - pub fn $remove_thunk(m: RawMap, prototype: RawMessage, key: $cpp_type) -> bool; + pub fn $remove_thunk(m: RawMap, prototype: MapValue, key: $cpp_type) -> bool; } )* }; @@ -836,6 +1193,8 @@ macro_rules! impl_map_primitives { paste!{ impl_map_primitives!(@impl $( ($rust_type, $cpp_type) => [ + [< proto2_rust_map_free_ $rust_type >], + [< proto2_rust_map_clear_ $rust_type >], [< proto2_rust_map_insert_ $rust_type >], [< proto2_rust_map_get_ $rust_type >], [< proto2_rust_map_iter_get_ $rust_type >], @@ -859,113 +1218,10 @@ extern "C" { fn proto2_rust_thunk_UntypedMapIterator_increment(iter: &mut UntypedMapIterator); pub fn proto2_rust_map_new() -> RawMap; - pub fn proto2_rust_map_free(m: RawMap, category: MapKeyCategory, prototype: RawMessage); - pub fn proto2_rust_map_clear(m: RawMap, category: MapKeyCategory, prototype: RawMessage); pub fn proto2_rust_map_size(m: RawMap) -> usize; pub fn proto2_rust_map_iter(m: RawMap) -> UntypedMapIterator; } -macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { - ($key_t:ty, $ffi_key_t:ty, $to_ffi_key:expr, $from_ffi_key:expr, for $($t:ty, $ffi_view_t:ty, $ffi_value_t:ty, $to_ffi_value:expr, $from_ffi_value:expr;)*) => { - paste! { $( - extern "C" { - pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _new >]() -> RawMap; - pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _free >](m: RawMap); - pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _clear >](m: RawMap); - pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _size >](m: RawMap) -> usize; - pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _insert >](m: RawMap, key: $ffi_key_t, value: $ffi_value_t) -> bool; - pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _get >](m: RawMap, key: $ffi_key_t, value: *mut $ffi_view_t) -> bool; - pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _iter >](m: RawMap) -> UntypedMapIterator; - pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _iter_get >](iter: &mut UntypedMapIterator, key: *mut $ffi_key_t, value: *mut $ffi_view_t); - pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _remove >](m: RawMap, key: $ffi_key_t, value: *mut $ffi_view_t) -> bool; - } - - impl ProxiedInMapValue<$key_t> for $t { - fn map_new(_private: Private) -> Map<$key_t, Self> { - unsafe { - Map::from_inner( - Private, - InnerMap { - raw: [< proto2_rust_thunk_Map_ $key_t _ $t _new >](), - } - ) - } - } - - unsafe fn map_free(_private: Private, map: &mut Map<$key_t, Self>) { - // SAFETY: - // - `map.inner.raw` is a live `RawMap` - // - This function is only called once for `map` in `Drop`. - unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _free >](map.as_mut().as_raw(Private)); } - } - - - fn map_clear(mut map: MapMut<$key_t, Self>) { - unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _clear >](map.as_raw(Private)); } - } - - fn map_len(map: MapView<$key_t, Self>) -> usize { - unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _size >](map.as_raw(Private)) } - } - - fn map_insert(mut map: MapMut<$key_t, Self>, key: View<'_, $key_t>, value: impl IntoProxied) -> bool { - let ffi_key = $to_ffi_key(key); - let ffi_value = $to_ffi_value(value.into_proxied(Private)); - unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _insert >](map.as_raw(Private), ffi_key, ffi_value) } - } - - fn map_get<'a>(map: MapView<'a, $key_t, Self>, key: View<'_, $key_t>) -> Option> { - let ffi_key = $to_ffi_key(key); - let mut ffi_value = MaybeUninit::uninit(); - let found = unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _get >](map.as_raw(Private), ffi_key, ffi_value.as_mut_ptr()) }; - - if !found { - return None; - } - // SAFETY: if `found` is true, then the `ffi_value` was written to by `get`. - Some($from_ffi_value(unsafe { ffi_value.assume_init() })) - } - - fn map_remove(mut map: MapMut<$key_t, Self>, key: View<'_, $key_t>) -> bool { - let ffi_key = $to_ffi_key(key); - let mut ffi_value = MaybeUninit::uninit(); - unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _remove >](map.as_raw(Private), ffi_key, ffi_value.as_mut_ptr()) } - } - - fn map_iter(map: MapView<$key_t, Self>) -> MapIter<$key_t, Self> { - // SAFETY: - // - The backing map for `map.as_raw` is valid for at least '_. - // - A View that is live for '_ guarantees the backing map is unmodified for '_. - // - The `iter` function produces an iterator that is valid for the key - // and value types, and live for at least '_. - unsafe { - MapIter::from_raw( - Private, - [< proto2_rust_thunk_Map_ $key_t _ $t _iter >](map.as_raw(Private)) - ) - } - } - - fn map_iter_next<'a>(iter: &mut MapIter<'a, $key_t, Self>) -> Option<(View<'a, $key_t>, View<'a, Self>)> { - // SAFETY: - // - The `MapIter` API forbids the backing map from being mutated for 'a, - // and guarantees that it's the correct key and value types. - // - The thunk is safe to call as long as the iterator isn't at the end. - // - The thunk always writes to key and value fields and does not read. - // - The thunk does not increment the iterator. - unsafe { - iter.as_raw_mut(Private).next_unchecked::<$key_t, Self, _, _>( - |iter, key, value| { [< proto2_rust_thunk_Map_ $key_t _ $t _iter_get >](iter, key, value) }, - $from_ffi_key, - $from_ffi_value, - ) - } - } - } - )* } - } -} - fn str_to_ptrlen<'msg>(val: impl Into<&'msg ProtoStr>) -> PtrAndLen { val.into().as_bytes().into() } @@ -990,36 +1246,6 @@ fn ptrlen_to_bytes<'msg>(val: PtrAndLen) -> &'msg [u8] { unsafe { val.as_ref() } } -macro_rules! impl_ProxiedInMapValue_for_key_types { - ($($t:ty, $ffi_t:ty, $to_ffi_key:expr, $from_ffi_key:expr;)*) => { - paste! { - $( - impl_ProxiedInMapValue_for_non_generated_value_types!( - $t, $ffi_t, $to_ffi_key, $from_ffi_key, for - f32, f32, f32, identity, identity; - f64, f64, f64, identity, identity; - i32, i32, i32, identity, identity; - u32, u32, u32, identity, identity; - i64, i64, i64, identity, identity; - u64, u64, u64, identity, identity; - bool, bool, bool, identity, identity; - ProtoString, PtrAndLen, CppStdString, protostr_into_cppstdstring, ptrlen_to_str; - ProtoBytes, PtrAndLen, CppStdString, protobytes_into_cppstdstring, ptrlen_to_bytes; - ); - )* - } - } -} - -impl_ProxiedInMapValue_for_key_types!( - i32, i32, identity, identity; - u32, u32, identity, identity; - i64, i64, identity, identity; - u64, u64, identity, identity; - bool, bool, identity, identity; - ProtoString, PtrAndLen, str_to_ptrlen, ptrlen_to_str; -); - #[cfg(test)] mod tests { use super::*; diff --git a/rust/cpp_kernel/BUILD b/rust/cpp_kernel/BUILD index 041b0c7a87..b428fbaa26 100644 --- a/rust/cpp_kernel/BUILD +++ b/rust/cpp_kernel/BUILD @@ -15,7 +15,6 @@ cc_library( hdrs = [ "compare.h", "debug.h", - "map.h", "rust_alloc_for_cpp_api.h", "serialized_data.h", "strings.h", diff --git a/rust/cpp_kernel/map.cc b/rust/cpp_kernel/map.cc index 7925d2073c..9ff049138a 100644 --- a/rust/cpp_kernel/map.cc +++ b/rust/cpp_kernel/map.cc @@ -1,5 +1,6 @@ -#include "rust/cpp_kernel/map.h" +#include "google/protobuf/map.h" +#include #include #include #include @@ -7,7 +8,6 @@ #include #include "absl/log/absl_log.h" -#include "google/protobuf/map.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "rust/cpp_kernel/strings.h" @@ -17,6 +17,27 @@ namespace protobuf { namespace rust { namespace { +// LINT.IfChange(map_ffi) +enum class MapValueTag : uint8_t { + kBool, + kU32, + kU64, + kString, + kMessage, +}; + +struct MapValue { + MapValueTag tag; + union { + bool b; + uint32_t u32; + uint64_t u64; + std::string* s; + google::protobuf::MessageLite* message; + }; +}; +// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_ffi) + template struct FromViewType { using type = T; @@ -31,41 +52,85 @@ template using KeyMap = internal::KeyMapBase< internal::KeyForBase::type>>; -internal::MapNodeSizeInfoT GetSizeInfo(size_t key_size, - const google::protobuf::MessageLite* value) { +void GetSizeAndAlignment(MapValue value, uint16_t* size, uint8_t* alignment) { + switch (value.tag) { + case MapValueTag::kBool: + *size = sizeof(bool); + *alignment = alignof(bool); + break; + case MapValueTag::kU32: + *size = sizeof(uint32_t); + *alignment = alignof(uint32_t); + break; + case MapValueTag::kU64: + *size = sizeof(uint64_t); + *alignment = alignof(uint64_t); + break; + case MapValueTag::kString: + *size = sizeof(std::string); + *alignment = alignof(std::string); + break; + case MapValueTag::kMessage: + internal::RustMapHelper::GetSizeAndAlignment(value.message, size, + alignment); + break; + default: + ABSL_DLOG(FATAL) << "Unexpected value of MapValue"; + } +} + +internal::MapNodeSizeInfoT GetSizeInfo(size_t key_size, MapValue value) { // Each map node consists of a NodeBase followed by a std::pair. // We need to compute the offset of the value and the total size of the node. size_t node_and_key_size = sizeof(internal::NodeBase) + key_size; uint16_t value_size; uint8_t value_alignment; - internal::RustMapHelper::GetSizeAndAlignment(value, &value_size, - &value_alignment); + GetSizeAndAlignment(value, &value_size, &value_alignment); // Round node_and_key_size up to the nearest multiple of value_alignment. uint16_t offset = (((node_and_key_size - 1) / value_alignment) + 1) * value_alignment; - return internal::RustMapHelper::MakeSizeInfo(offset + value_size, offset); -} -template -internal::MapNodeSizeInfoT GetSizeInfo(const google::protobuf::MessageLite* value) { - return GetSizeInfo(sizeof(Key), value); + size_t overall_alignment = std::max(alignof(internal::NodeBase), + static_cast(value_alignment)); + // Round up size to nearest multiple of overall_alignment. + size_t overall_size = + (((offset + value_size - 1) / overall_alignment) + 1) * overall_alignment; + + return internal::RustMapHelper::MakeSizeInfo(overall_size, offset); } template void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node, - internal::MapNodeSizeInfoT size_info) { + internal::MapNodeSizeInfoT size_info, + bool destroy_message) { if constexpr (std::is_same::value) { static_cast(node->GetVoidKey())->~basic_string(); } - internal::RustMapHelper::DestroyMessage( - static_cast(node->GetVoidValue(size_info))); + if (destroy_message) { + internal::RustMapHelper::DestroyMessage( + static_cast(node->GetVoidValue(size_info))); + } internal::RustMapHelper::DeallocNode(m, node, size_info); } +void InitializeMessageValue(void* raw_ptr, MessageLite* msg) { + MessageLite* new_msg = internal::RustMapHelper::PlacementNew(msg, raw_ptr); + auto* full_msg = DynamicCastMessage(new_msg); + + // If we are working with a full (non-lite) proto, we reflectively swap the + // value into place. Otherwise, we have to perform a copy. + if (full_msg != nullptr) { + full_msg->GetReflection()->Swap(full_msg, DynamicCastMessage(msg)); + } else { + new_msg->CheckTypeAndMergeFrom(*msg); + } + delete msg; +} + template -bool Insert(internal::UntypedMapBase* m, Key key, MessageLite* value) { +bool Insert(internal::UntypedMapBase* m, Key key, MapValue value) { internal::MapNodeSizeInfoT size_info = - GetSizeInfo::type>(value); + GetSizeInfo(sizeof(typename FromViewType::type), value); internal::NodeBase* node = internal::RustMapHelper::AllocNode(m, size_info); if constexpr (std::is_same::value) { new (node->GetVoidKey()) std::string(key.ptr, key.len); @@ -73,17 +138,26 @@ bool Insert(internal::UntypedMapBase* m, Key key, MessageLite* value) { *static_cast(node->GetVoidKey()) = key; } - MessageLite* new_msg = internal::RustMapHelper::PlacementNew( - value, node->GetVoidValue(size_info)); - auto* full_msg = DynamicCastMessage(new_msg); - - // If we are working with a full (non-lite) proto, we reflectively swap the - // value into place. Otherwise, we have to perform a copy. - if (full_msg != nullptr) { - full_msg->GetReflection()->Swap(full_msg, - DynamicCastMessage(value)); - } else { - new_msg->CheckTypeAndMergeFrom(*value); + void* value_ptr = node->GetVoidValue(size_info); + switch (value.tag) { + case MapValueTag::kBool: + *static_cast(value_ptr) = value.b; + break; + case MapValueTag::kU32: + *static_cast(value_ptr) = value.u32; + break; + case MapValueTag::kU64: + *static_cast(value_ptr) = value.u64; + break; + case MapValueTag::kString: + new (value_ptr) std::string(std::move(*value.s)); + delete value.s; + break; + case MapValueTag::kMessage: + InitializeMessageValue(value_ptr, value.message); + break; + default: + ABSL_DLOG(FATAL) << "Unexpected value of MapValue"; } node = internal::RustMapHelper::InsertOrReplaceNode( @@ -91,7 +165,7 @@ bool Insert(internal::UntypedMapBase* m, Key key, MessageLite* value) { if (node == nullptr) { return true; } - DestroyMapNode(m, node, size_info); + DestroyMapNode(m, node, size_info, value.tag == MapValueTag::kMessage); return false; } @@ -110,41 +184,63 @@ internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, m, absl::string_view(key.ptr, key.len)); } +void PopulateMapValue(MapValueTag tag, void* data, MapValue& output) { + output.tag = tag; + switch (tag) { + case MapValueTag::kBool: + output.b = *static_cast(data); + break; + case MapValueTag::kU32: + output.u32 = *static_cast(data); + break; + case MapValueTag::kU64: + output.u64 = *static_cast(data); + break; + case MapValueTag::kString: + output.s = static_cast(data); + break; + case MapValueTag::kMessage: + output.message = static_cast(data); + break; + default: + ABSL_DLOG(FATAL) << "Unexpected MapValueTag"; + } +} + template -bool Get(internal::UntypedMapBase* m, const google::protobuf::MessageLite* prototype, - Key key, MessageLite** value) { +bool Get(internal::UntypedMapBase* m, MapValue prototype, Key key, + MapValue* value) { internal::MapNodeSizeInfoT size_info = - GetSizeInfo::type>(prototype); + GetSizeInfo(sizeof(typename FromViewType::type), prototype); auto* map_base = static_cast*>(m); internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); if (result.node == nullptr) { return false; } - *value = static_cast(result.node->GetVoidValue(size_info)); + PopulateMapValue(prototype.tag, result.node->GetVoidValue(size_info), *value); return true; } template -bool Remove(internal::UntypedMapBase* m, const google::protobuf::MessageLite* prototype, - Key key) { +bool Remove(internal::UntypedMapBase* m, MapValue prototype, Key key) { internal::MapNodeSizeInfoT size_info = - GetSizeInfo::type>(prototype); + GetSizeInfo(sizeof(typename FromViewType::type), prototype); auto* map_base = static_cast*>(m); internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); if (result.node == nullptr) { return false; } internal::RustMapHelper::EraseNoDestroy(map_base, result.bucket, result.node); - DestroyMapNode(m, result.node, size_info); + DestroyMapNode(m, result.node, size_info, + prototype.tag == MapValueTag::kMessage); return true; } template -void IterGet(const internal::UntypedMapIterator* iter, - const google::protobuf::MessageLite* prototype, Key* key, - MessageLite** value) { +void IterGet(const internal::UntypedMapIterator* iter, MapValue prototype, + Key* key, MapValue* value) { internal::MapNodeSizeInfoT size_info = - GetSizeInfo::type>(prototype); + GetSizeInfo(sizeof(typename FromViewType::type), prototype); internal::NodeBase* node = iter->node_; if constexpr (std::is_same::value) { const std::string* s = static_cast(node->GetVoidKey()); @@ -152,42 +248,35 @@ void IterGet(const internal::UntypedMapIterator* iter, } else { *key = *static_cast(node->GetVoidKey()); } - *value = static_cast(node->GetVoidValue(size_info)); + PopulateMapValue(prototype.tag, node->GetVoidValue(size_info), *value); } -// LINT.IfChange(map_key_category) -enum class MapKeyCategory : uint8_t { - kOneByte = 0, - kFourBytes = 1, - kEightBytes = 2, - kStdString = 3, -}; -// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_key_category) - -size_t KeySize(MapKeyCategory category) { - switch (category) { - case MapKeyCategory::kOneByte: - return 1; - case MapKeyCategory::kFourBytes: - return 4; - case MapKeyCategory::kEightBytes: - return 8; - case MapKeyCategory::kStdString: - return sizeof(std::string); - default: - ABSL_DLOG(FATAL) << "Unexpected value of MapKeyCategory enum"; +// Returns the size of the key in the map entry, given the key used for FFI. +// The map entry key and FFI key are always the same, except in the case of +// string and bytes. +template +size_t KeySize() { + if constexpr (std::is_same::value) { + return sizeof(std::string); + } else { + return sizeof(Key); } } -void ClearMap(internal::UntypedMapBase* m, MapKeyCategory category, - bool reset_table, const google::protobuf::MessageLite* prototype) { - internal::MapNodeSizeInfoT size_info = - GetSizeInfo(KeySize(category), prototype); +template +void ClearMap(internal::UntypedMapBase* m, bool reset_table, + MapValue prototype) { + internal::MapNodeSizeInfoT size_info = GetSizeInfo(KeySize(), prototype); if (internal::RustMapHelper::IsGlobalEmptyTable(m)) return; - uint8_t bits = internal::RustMapHelper::kValueIsProto; - if (category == MapKeyCategory::kStdString) { + uint8_t bits = 0; + if constexpr (std::is_same::value) { bits |= internal::RustMapHelper::kKeyIsString; } + if (prototype.tag == MapValueTag::kString) { + bits |= internal::RustMapHelper::kValueIsString; + } else if (prototype.tag == MapValueTag::kMessage) { + bits |= internal::RustMapHelper::kValueIsProto; + } internal::RustMapHelper::ClearTable( m, internal::RustMapHelper::ClearInput{size_info, bits, reset_table, /* destroy_node = */ nullptr}); @@ -209,19 +298,6 @@ google::protobuf::internal::UntypedMapBase* proto2_rust_map_new() { return new google::protobuf::internal::UntypedMapBase(/* arena = */ nullptr); } -void proto2_rust_map_free(google::protobuf::internal::UntypedMapBase* m, - google::protobuf::rust::MapKeyCategory category, - const google::protobuf::MessageLite* prototype) { - google::protobuf::rust::ClearMap(m, category, /* reset_table = */ false, prototype); - delete m; -} - -void proto2_rust_map_clear(google::protobuf::internal::UntypedMapBase* m, - google::protobuf::rust::MapKeyCategory category, - const google::protobuf::MessageLite* prototype) { - google::protobuf::rust::ClearMap(m, category, /* reset_table = */ true, prototype); -} - size_t proto2_rust_map_size(google::protobuf::internal::UntypedMapBase* m) { return m->size(); } @@ -231,31 +307,39 @@ google::protobuf::internal::UntypedMapIterator proto2_rust_map_iter( return m->begin(); } -#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ - bool proto2_rust_map_insert_##suffix(google::protobuf::internal::UntypedMapBase* m, \ - cpp_type key, \ - google::protobuf::MessageLite* value) { \ - return google::protobuf::rust::Insert(m, key, value); \ - } \ - \ - bool proto2_rust_map_get_##suffix(google::protobuf::internal::UntypedMapBase* m, \ - const google::protobuf::MessageLite* prototype, \ - cpp_type key, \ - google::protobuf::MessageLite** value) { \ - return google::protobuf::rust::Get(m, prototype, key, value); \ - } \ - \ - bool proto2_rust_map_remove_##suffix(google::protobuf::internal::UntypedMapBase* m, \ - const google::protobuf::MessageLite* prototype, \ - cpp_type key) { \ - return google::protobuf::rust::Remove(m, prototype, key); \ - } \ - \ - void proto2_rust_map_iter_get_##suffix( \ - const google::protobuf::internal::UntypedMapIterator* iter, \ - const google::protobuf::MessageLite* prototype, cpp_type* key, \ - google::protobuf::MessageLite** value) { \ - return google::protobuf::rust::IterGet(iter, prototype, key, value); \ +#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \ + void proto2_rust_map_free_##suffix(google::protobuf::internal::UntypedMapBase* m, \ + google::protobuf::rust::MapValue prototype) { \ + google::protobuf::rust::ClearMap(m, /* reset_table = */ false, prototype); \ + delete m; \ + } \ + void proto2_rust_map_clear_##suffix(google::protobuf::internal::UntypedMapBase* m, \ + google::protobuf::rust::MapValue prototype) { \ + google::protobuf::rust::ClearMap(m, /* reset_table = */ true, prototype); \ + } \ + bool proto2_rust_map_insert_##suffix(google::protobuf::internal::UntypedMapBase* m, \ + cpp_type key, \ + google::protobuf::rust::MapValue value) { \ + return google::protobuf::rust::Insert(m, key, value); \ + } \ + \ + bool proto2_rust_map_get_##suffix( \ + google::protobuf::internal::UntypedMapBase* m, google::protobuf::rust::MapValue prototype, \ + cpp_type key, google::protobuf::rust::MapValue* value) { \ + return google::protobuf::rust::Get(m, prototype, key, value); \ + } \ + \ + bool proto2_rust_map_remove_##suffix(google::protobuf::internal::UntypedMapBase* m, \ + google::protobuf::rust::MapValue prototype, \ + cpp_type key) { \ + return google::protobuf::rust::Remove(m, prototype, key); \ + } \ + \ + void proto2_rust_map_iter_get_##suffix( \ + const google::protobuf::internal::UntypedMapIterator* iter, \ + google::protobuf::rust::MapValue prototype, cpp_type* key, \ + google::protobuf::rust::MapValue* value) { \ + return google::protobuf::rust::IterGet(iter, prototype, key, value); \ } DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int32_t, i32) @@ -265,27 +349,4 @@ DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(uint64_t, u64) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(bool, bool) DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(google::protobuf::rust::PtrAndLen, ProtoString) -__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(int32_t, i32, int32_t, - int32_t, value, cpp_value); -__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(uint32_t, u32, uint32_t, - uint32_t, value, cpp_value); -__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(float, f32, float, float, - value, cpp_value); -__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(double, f64, double, double, - value, cpp_value); -__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(bool, bool, bool, bool, - value, cpp_value); -__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(uint64_t, u64, uint64_t, - uint64_t, value, cpp_value); -__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(int64_t, i64, int64_t, - int64_t, value, cpp_value); -__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE( - std::string, ProtoBytes, google::protobuf::rust::PtrAndLen, std::string*, - std::move(*value), - (google::protobuf::rust::PtrAndLen{cpp_value.data(), cpp_value.size()})); -__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE( - std::string, ProtoString, google::protobuf::rust::PtrAndLen, std::string*, - std::move(*value), - (google::protobuf::rust::PtrAndLen{cpp_value.data(), cpp_value.size()})); - } // extern "C" diff --git a/rust/cpp_kernel/map.h b/rust/cpp_kernel/map.h deleted file mode 100644 index 8e9881f446..0000000000 --- a/rust/cpp_kernel/map.h +++ /dev/null @@ -1,122 +0,0 @@ -#ifndef GOOGLE_PROTOBUF_RUST_CPP_KERNEL_MAP_H__ -#define GOOGLE_PROTOBUF_RUST_CPP_KERNEL_MAP_H__ - -#include -#include - -#include "google/protobuf/map.h" -#include "google/protobuf/message_lite.h" -#include "rust/cpp_kernel/strings.h" - -namespace google { -namespace protobuf { -namespace rust { - -// String and bytes values are passed across the FFI boundary as owned raw -// pointers when we do map insertions. Unlike other types, they have to be -// explicitly deleted. This MakeCleanup() helper does nothing by default, but -// for std::string pointers it returns a std::unique_ptr to take ownership of -// the raw pointer. -template -auto MakeCleanup(T value) { - if constexpr (std::is_same::value) { - return std::unique_ptr(value); - } else { - return 0; - } -} - -} // namespace rust -} // namespace protobuf -} // namespace google - -// Defines concrete thunks to access typed map methods from Rust. -#define __PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \ - key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, to_ffi_key, value_ty, \ - rust_value_ty, ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value) \ - google::protobuf::Map* \ - proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_new() { \ - return new google::protobuf::Map(); \ - } \ - void proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_free( \ - google::protobuf::Map* m) { \ - delete m; \ - } \ - void proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_clear( \ - google::protobuf::Map* m) { \ - m->clear(); \ - } \ - size_t proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_size( \ - const google::protobuf::Map* m) { \ - return m->size(); \ - } \ - bool proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_insert( \ - google::protobuf::Map* m, ffi_key_ty key, ffi_value_ty value) { \ - auto cleanup = google::protobuf::rust::MakeCleanup(value); \ - (void)cleanup; \ - auto iter_and_inserted = m->try_emplace(to_cpp_key, to_cpp_value); \ - if (!iter_and_inserted.second) { \ - iter_and_inserted.first->second = to_cpp_value; \ - } \ - return iter_and_inserted.second; \ - } \ - bool proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_get( \ - const google::protobuf::Map* m, ffi_key_ty key, \ - ffi_view_ty* value) { \ - auto cpp_key = to_cpp_key; \ - auto it = m->find(cpp_key); \ - if (it == m->end()) { \ - return false; \ - } \ - auto& cpp_value = it->second; \ - *value = to_ffi_value; \ - return true; \ - } \ - google::protobuf::internal::UntypedMapIterator \ - proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_iter( \ - const google::protobuf::Map* m) { \ - return google::protobuf::internal::UntypedMapIterator::FromTyped(m->cbegin()); \ - } \ - void proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_iter_get( \ - const google::protobuf::internal::UntypedMapIterator* iter, ffi_key_ty* key, \ - ffi_view_ty* value) { \ - auto typed_iter = \ - iter->ToTyped::const_iterator>(); \ - const auto& cpp_key = typed_iter->first; \ - const auto& cpp_value = typed_iter->second; \ - *key = to_ffi_key; \ - *value = to_ffi_value; \ - } \ - bool proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_remove( \ - google::protobuf::Map* m, ffi_key_ty key, ffi_view_ty* value) { \ - auto cpp_key = to_cpp_key; \ - auto num_removed = m->erase(cpp_key); \ - return num_removed > 0; \ - } - -// Defines the map thunks for all supported key types for a given value type. -#define __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE( \ - value_ty, rust_value_ty, ffi_view_ty, ffi_value_ty, to_cpp_value, \ - to_ffi_value) \ - __PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \ - int32_t, i32, int32_t, key, cpp_key, value_ty, rust_value_ty, \ - ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value); \ - __PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \ - uint32_t, u32, uint32_t, key, cpp_key, value_ty, rust_value_ty, \ - ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value); \ - __PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \ - bool, bool, bool, key, cpp_key, value_ty, rust_value_ty, ffi_view_ty, \ - ffi_value_ty, to_cpp_value, to_ffi_value); \ - __PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \ - uint64_t, u64, uint64_t, key, cpp_key, value_ty, rust_value_ty, \ - ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value); \ - __PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \ - int64_t, i64, int64_t, key, cpp_key, value_ty, rust_value_ty, \ - ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value); \ - __PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \ - std::string, ProtoString, google::protobuf::rust::PtrAndLen, \ - std::string(key.ptr, key.len), \ - (google::protobuf::rust::PtrAndLen{cpp_key.data(), cpp_key.size()}), value_ty, \ - rust_value_ty, ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value); - -#endif // GOOGLE_PROTOBUF_RUST_CPP_KERNEL_MAP_H__ diff --git a/src/google/protobuf/compiler/rust/enum.cc b/src/google/protobuf/compiler/rust/enum.cc index e405d4c9d1..8e0906bcc6 100644 --- a/src/google/protobuf/compiler/rust/enum.cc +++ b/src/google/protobuf/compiler/rust/enum.cc @@ -45,106 +45,26 @@ std::vector> EnumValuesInput( return result; } -void EnumProxiedInMapValue(Context& ctx, const EnumDescriptor& desc) { +void TypeConversions(Context& ctx, const EnumDescriptor& desc) { switch (ctx.opts().kernel) { case Kernel::kCpp: - for (const auto& t : kMapKeyTypes) { - ctx.Emit( - {{"map_new_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "new")}, - {"map_free_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "free")}, - {"map_clear_thunk", - RawMapThunk(ctx, desc, t.thunk_ident, "clear")}, - {"map_size_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "size")}, - {"map_insert_thunk", - RawMapThunk(ctx, desc, t.thunk_ident, "insert")}, - {"map_get_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "get")}, - {"map_remove_thunk", - RawMapThunk(ctx, desc, t.thunk_ident, "remove")}, - {"map_iter_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "iter")}, - {"map_iter_get_thunk", - RawMapThunk(ctx, desc, t.thunk_ident, "iter_get")}, - {"to_ffi_key_expr", t.rs_to_ffi_key_expr}, - io::Printer::Sub("ffi_key_t", [&] { ctx.Emit(t.rs_ffi_key_t); }) - .WithSuffix(""), - io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); }) - .WithSuffix(""), - io::Printer::Sub("from_ffi_key_expr", - [&] { ctx.Emit(t.rs_from_ffi_key_expr); }) - .WithSuffix("")}, - R"rs( - impl $pb$::ProxiedInMapValue<$key_t$> for $name$ { - fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> { - unsafe { - $pb$::Map::from_inner( - $pbi$::Private, - $pbr$::InnerMap::new($pbr$::$map_new_thunk$()) - ) - } - } - - unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) { - unsafe { $pbr$::$map_free_thunk$(map.as_raw($pbi$::Private)); } - } - - fn map_clear(mut map: $pb$::MapMut<$key_t$, Self>) { - unsafe { $pbr$::$map_clear_thunk$(map.as_raw($pbi$::Private)); } - } - - fn map_len(map: $pb$::MapView<$key_t$, Self>) -> usize { - unsafe { $pbr$::$map_size_thunk$(map.as_raw($pbi$::Private)) } - } - - fn map_insert(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>, value: impl $pb$::IntoProxied) -> bool { - unsafe { $pbr$::$map_insert_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value.into_proxied($pbi$::Private).0) } - } - - fn map_get<'a>(map: $pb$::MapView<'a, $key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> $Option$<$pb$::View<'a, Self>> { - let key = $to_ffi_key_expr$; - let mut value = $std$::mem::MaybeUninit::uninit(); - let found = unsafe { $pbr$::$map_get_thunk$(map.as_raw($pbi$::Private), key, value.as_mut_ptr()) }; - if !found { - return None; - } - Some(unsafe { $name$(value.assume_init()) }) - } - - fn map_remove(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> bool { - let mut value = $std$::mem::MaybeUninit::uninit(); - unsafe { $pbr$::$map_remove_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value.as_mut_ptr()) } - } + ctx.Emit( + R"rs( + impl $pbr$::CppMapTypeConversions for $name$ { + fn get_prototype() -> $pbr$::MapValue { + Self::to_map_value(Self::default()) + } - fn map_iter(map: $pb$::MapView<$key_t$, Self>) -> $pb$::MapIter<$key_t$, Self> { - // SAFETY: - // - The backing map for `map.as_raw` is valid for at least '_. - // - A View that is live for '_ guarantees the backing map is unmodified for '_. - // - The `iter` function produces an iterator that is valid for the key - // and value types, and live for at least '_. - unsafe { - $pb$::MapIter::from_raw( - $pbi$::Private, - $pbr$::$map_iter_thunk$(map.as_raw($pbi$::Private)) - ) - } - } + fn to_map_value(self) -> $pbr$::MapValue { + $pbr$::MapValue::make_u32(self.0 as u32) + } - fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> $Option$<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> { - // SAFETY: - // - The `MapIter` API forbids the backing map from being mutated for 'a, - // and guarantees that it's the correct key and value types. - // - The thunk is safe to call as long as the iterator isn't at the end. - // - The thunk always writes to key and value fields and does not read. - // - The thunk does not increment the iterator. - unsafe { - iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>( - |iter, key, value| { $pbr$::$map_iter_get_thunk$(iter, key, value) }, - |ffi_key| $from_ffi_key_expr$, - |value| $name$(value), - ) - } - } - } - )rs"); - } + unsafe fn from_map_value<'a>(value: $pbr$::MapValue) -> $pb$::View<'a, Self> { + debug_assert_eq!(value.tag, $pbr$::MapValueTag::U32); + $name$(unsafe { value.val.u as i32 }) + } + } + )rs"); return; case Kernel::kUpb: ctx.Emit(R"rs( @@ -277,7 +197,7 @@ void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) { )rs"); } }}, - {"impl_proxied_in_map", [&] { EnumProxiedInMapValue(ctx, desc); }}, + {"type_conversions_impl", [&] { TypeConversions(ctx, desc); }}, }, R"rs( #[repr(transparent)] @@ -411,7 +331,7 @@ void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) { } } - $impl_proxied_in_map$ + $type_conversions_impl$ )rs"); } diff --git a/src/google/protobuf/compiler/rust/generator.cc b/src/google/protobuf/compiler/rust/generator.cc index ecadf48fb0..20b19e8135 100644 --- a/src/google/protobuf/compiler/rust/generator.cc +++ b/src/google/protobuf/compiler/rust/generator.cc @@ -225,7 +225,6 @@ bool RustGenerator::Generate(const FileDescriptor* file, #include "google/protobuf/map.h" #include "google/protobuf/repeated_field.h" #include "google/protobuf/repeated_ptr_field.h" -#include "rust/cpp_kernel/map.h" #include "rust/cpp_kernel/serialized_data.h" #include "rust/cpp_kernel/strings.h" )cc"); diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 509b8a0da5..d0600670e4 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -573,120 +573,29 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) { ABSL_LOG(FATAL) << "unreachable"; } -void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) { +void TypeConversions(Context& ctx, const Descriptor& msg) { switch (ctx.opts().kernel) { case Kernel::kCpp: - for (const auto& t : kMapKeyTypes) { - ctx.Emit( - {{"map_insert", - absl::StrCat("proto2_rust_map_insert_", t.thunk_ident)}, - {"map_remove", - absl::StrCat("proto2_rust_map_remove_", t.thunk_ident)}, - {"map_get", absl::StrCat("proto2_rust_map_get_", t.thunk_ident)}, - {"map_iter_get", - absl::StrCat("proto2_rust_map_iter_get_", t.thunk_ident)}, - {"key_expr", t.rs_to_ffi_key_expr}, - {"key_is_string", - t.thunk_ident == "ProtoString" ? "true" : "false"}, - io::Printer::Sub("ffi_key_t", [&] { ctx.Emit(t.rs_ffi_key_t); }) - .WithSuffix(""), - io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); }) - .WithSuffix(""), - io::Printer::Sub("from_ffi_key_expr", - [&] { ctx.Emit(t.rs_from_ffi_key_expr); }) - .WithSuffix("")}, - R"rs( - impl $pb$::ProxiedInMapValue<$key_t$> for $Msg$ { - fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> { - unsafe { - $pb$::Map::from_inner( - $pbi$::Private, - $pbr$::InnerMap::new($pbr$::proto2_rust_map_new()) - ) - } - } - - unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) { - use $pbr$::MapKey; - unsafe { $pbr$::proto2_rust_map_free(map.as_raw($pbi$::Private), $key_t$::CATEGORY, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg()); } - } - - fn map_clear(mut map: $pb$::MapMut<$key_t$, Self>) { - use $pbr$::MapKey; - unsafe { $pbr$::proto2_rust_map_clear(map.as_raw($pbi$::Private), $key_t$::CATEGORY, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg()); } - } - - fn map_len(map: $pb$::MapView<$key_t$, Self>) -> usize { - unsafe { $pbr$::proto2_rust_map_size(map.as_raw($pbi$::Private)) } - } - - fn map_insert(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>, value: impl $pb$::IntoProxied) -> bool { - unsafe { - $pbr$::$map_insert$( - map.as_raw($pbi$::Private), - $key_expr$, - value.into_proxied($pbi$::Private).raw_msg()) - } - } - - fn map_get<'a>(map: $pb$::MapView<'a, $key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> $Option$<$pb$::View<'a, Self>> { - let key = $key_expr$; - let mut value = $std$::mem::MaybeUninit::uninit(); - let found = unsafe { - $pbr$::$map_get$( - map.as_raw($pbi$::Private), - <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(), - key, - value.as_mut_ptr()) - }; - if !found { - return None; - } - Some($Msg$View::new($pbi$::Private, unsafe { value.assume_init() })) - } - - fn map_remove(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> bool { - unsafe { - $pbr$::$map_remove$( - map.as_raw($pbi$::Private), - <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(), - $key_expr$) - } - } + ctx.Emit( + R"rs( + impl $pbr$::CppMapTypeConversions for $Msg$ { + fn get_prototype() -> $pbr$::MapValue { + $pbr$::MapValue::make_message(<$Msg$View as $std$::default::Default>::default().raw_msg()) + } - fn map_iter(map: $pb$::MapView<$key_t$, Self>) -> $pb$::MapIter<$key_t$, Self> { - // SAFETY: - // - The backing map for `map.as_raw` is valid for at least '_. - // - A View that is live for '_ guarantees the backing map is unmodified for '_. - // - The `iter` function produces an iterator that is valid for the key - // and value types, and live for at least '_. - unsafe { - $pb$::MapIter::from_raw( - $pbi$::Private, - $pbr$::proto2_rust_map_iter(map.as_raw($pbi$::Private)) - ) - } - } + fn to_map_value(self) -> $pbr$::MapValue { + use $pb$::OwnedMessageInterop; + $pbr$::MapValue::make_message(unsafe { + $NonNull$::new_unchecked(self.__unstable_leak_raw_message() as *mut _) + }) + } - fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> $Option$<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> { - // SAFETY: - // - The `MapIter` API forbids the backing map from being mutated for 'a, - // and guarantees that it's the correct key and value types. - // - The thunk is safe to call as long as the iterator isn't at the end. - // - The thunk always writes to key and value fields and does not read. - // - The thunk does not increment the iterator. - unsafe { - iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>( - |iter, key, value| { $pbr$::$map_iter_get$( - iter, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(), key, value) }, - |ffi_key| $from_ffi_key_expr$, - |raw_msg| $Msg$View::new($pbi$::Private, raw_msg) - ) - } - } - } - )rs"); - } + unsafe fn from_map_value<'b>(value: $pbr$::MapValue) -> $Msg$View<'b> { + debug_assert_eq!(value.tag, $pbr$::MapValueTag::Message); + unsafe { $Msg$View::new($pbi$::Private, value.val.m) } + } + } + )rs"); return; case Kernel::kUpb: ctx.Emit( @@ -845,7 +754,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { {"upb_generated_message_trait_impls", [&] { UpbGeneratedMessageTraitImpls(ctx, msg); }}, {"repeated_impl", [&] { MessageProxiedInRepeated(ctx, msg); }}, - {"map_value_impl", [&] { MessageProxiedInMapValue(ctx, msg); }}, + {"type_conversions_impl", [&] { TypeConversions(ctx, msg); }}, {"unwrap_upb", [&] { if (ctx.is_upb()) { @@ -1008,7 +917,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { $into_proxied_impl$ $repeated_impl$ - $map_value_impl$ + $type_conversions_impl$ #[allow(dead_code)] #[allow(non_camel_case_types)] diff --git a/src/google/protobuf/compiler/rust/naming.cc b/src/google/protobuf/compiler/rust/naming.cc index 1b207fa80b..e70c487c59 100644 --- a/src/google/protobuf/compiler/rust/naming.cc +++ b/src/google/protobuf/compiler/rust/naming.cc @@ -405,44 +405,6 @@ absl::string_view MultiCasePrefixStripper::StripPrefix( return name; } -PROTOBUF_CONSTINIT const MapKeyType kMapKeyTypes[] = { - {/*thunk_ident=*/"i32", /*rs_key_t=*/"i32", /*rs_ffi_key_t=*/"i32", - /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key", - /*cc_key_t=*/"int32_t", /*cc_ffi_key_t=*/"int32_t", - /*cc_from_ffi_key_expr=*/"key", - /*cc_to_ffi_key_expr=*/"cpp_key"}, - {/*thunk_ident=*/"u32", /*rs_key_t=*/"u32", /*rs_ffi_key_t=*/"u32", - /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key", - /*cc_key_t=*/"uint32_t", /*cc_ffi_key_t=*/"uint32_t", - /*cc_from_ffi_key_expr=*/"key", - /*cc_to_ffi_key_expr=*/"cpp_key"}, - {/*thunk_ident=*/"i64", /*rs_key_t=*/"i64", /*rs_ffi_key_t=*/"i64", - /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key", - /*cc_key_t=*/"int64_t", /*cc_ffi_key_t=*/"int64_t", - /*cc_from_ffi_key_expr=*/"key", - /*cc_to_ffi_key_expr=*/"cpp_key"}, - {/*thunk_ident=*/"u64", /*rs_key_t=*/"u64", /*rs_ffi_key_t=*/"u64", - /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key", - /*cc_key_t=*/"uint64_t", /*cc_ffi_key_t=*/"uint64_t", - /*cc_from_ffi_key_expr=*/"key", - /*cc_to_ffi_key_expr=*/"cpp_key"}, - {/*thunk_ident=*/"bool", /*rs_key_t=*/"bool", /*rs_ffi_key_t=*/"bool", - /*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key", - /*cc_key_t=*/"bool", /*cc_ffi_key_t=*/"bool", - /*cc_from_ffi_key_expr=*/"key", - /*cc_to_ffi_key_expr=*/"cpp_key"}, - {/*thunk_ident=*/"ProtoString", - /*rs_key_t=*/"$pb$::ProtoString", - /*rs_ffi_key_t=*/"$pbr$::PtrAndLen", - /*rs_to_ffi_key_expr=*/"key.as_bytes().into()", - /*rs_from_ffi_key_expr=*/ - "$pb$::ProtoStr::from_utf8_unchecked(ffi_key.as_ref())", - /*cc_key_t=*/"std::string", - /*cc_ffi_key_t=*/"google::protobuf::rust::PtrAndLen", - /*cc_from_ffi_key_expr=*/ - "std::string(key.ptr, key.len)", /*cc_to_ffi_key_expr=*/ - "google::protobuf::rust::PtrAndLen{cpp_key.data(), cpp_key.size()}"}}; - } // namespace rust } // namespace compiler } // namespace protobuf diff --git a/src/google/protobuf/map.h b/src/google/protobuf/map.h index 92a6140fda..a4509333da 100644 --- a/src/google/protobuf/map.h +++ b/src/google/protobuf/map.h @@ -1194,6 +1194,7 @@ class RustMapHelper { enum { kKeyIsString = UntypedMapBase::kKeyIsString, + kValueIsString = UntypedMapBase::kValueIsString, kValueIsProto = UntypedMapBase::kValueIsProto, };