diff --git a/rust/BUILD b/rust/BUILD index 2bd06aed0f..bc6a5cc9f4 100644 --- a/rust/BUILD +++ b/rust/BUILD @@ -56,6 +56,7 @@ PROTOBUF_SHARED = [ "shared.rs", "string.rs", "vtable.rs", + "map.rs", ] # The Rust Protobuf runtime using the upb kernel. @@ -65,11 +66,11 @@ PROTOBUF_SHARED = [ # setting. rust_library( name = "protobuf_upb", - srcs = PROTOBUF_SHARED + [ - "map.rs", - "upb.rs", - ], + srcs = PROTOBUF_SHARED + ["upb.rs"], crate_root = "shared.rs", + proc_macro_deps = [ + "@crate_index//:paste", + ], rustc_flags = ["--cfg=upb_kernel"], visibility = [ "//src/google/protobuf:__subpackages__", diff --git a/rust/cpp.rs b/rust/cpp.rs index 4346d4173f..11fd63462c 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -7,7 +7,7 @@ // Rust Protobuf runtime using the C++ kernel. -use crate::__internal::{Private, RawArena, RawMessage, RawRepeatedField}; +use crate::__internal::{Private, RawArena, RawMap, RawMessage, RawRepeatedField}; use paste::paste; use std::alloc::Layout; use std::cell::UnsafeCell; @@ -319,6 +319,143 @@ impl<'msg, T: RepeatedScalarOps> RepeatedField<'msg, T> { } } +#[derive(Debug)] +pub struct Map<'msg, K: ?Sized, V: ?Sized> { + inner: MapInner<'msg>, + _phantom_key: PhantomData<&'msg mut K>, + _phantom_value: PhantomData<&'msg mut V>, +} + +#[derive(Clone, Copy, Debug)] +pub struct MapInner<'msg> { + pub raw: RawMap, + pub _phantom: PhantomData<&'msg ()>, +} + +// These use manual impls instead of derives to avoid unnecessary bounds on `K` +// and `V`. This problem is referred to as "perfect derive". +// https://smallcultfollowing.com/babysteps/blog/2022/04/12/implied-bounds-and-perfect-derive/ +impl<'msg, K: ?Sized, V: ?Sized> Copy for Map<'msg, K, V> {} +impl<'msg, K: ?Sized, V: ?Sized> Clone for Map<'msg, K, V> { + fn clone(&self) -> Map<'msg, K, V> { + *self + } +} + +impl<'msg, K: ?Sized, V: ?Sized> Map<'msg, K, V> { + pub fn from_inner(_private: Private, inner: MapInner<'msg>) -> Self { + Map { inner, _phantom_key: PhantomData, _phantom_value: PhantomData } + } +} + +macro_rules! impl_scalar_map_values { + ($kt:ty, $trait:ident for $($t:ty),*) => { + paste! { $( + extern "C" { + fn [< __pb_rust_Map_ $kt _ $t _new >]() -> RawMap; + fn [< __pb_rust_Map_ $kt _ $t _clear >](m: RawMap); + fn [< __pb_rust_Map_ $kt _ $t _size >](m: RawMap) -> usize; + fn [< __pb_rust_Map_ $kt _ $t _insert >](m: RawMap, key: $kt, value: $t); + fn [< __pb_rust_Map_ $kt _ $t _get >](m: RawMap, key: $kt, value: *mut $t) -> bool; + fn [< __pb_rust_Map_ $kt _ $t _remove >](m: RawMap, key: $kt, value: *mut $t) -> bool; + } + impl $trait for $t { + fn new_map() -> RawMap { + unsafe { [< __pb_rust_Map_ $kt _ $t _new >]() } + } + + fn clear(m: RawMap) { + unsafe { [< __pb_rust_Map_ $kt _ $t _clear >](m) } + } + + fn size(m: RawMap) -> usize { + unsafe { [< __pb_rust_Map_ $kt _ $t _size >](m) } + } + + fn insert(m: RawMap, key: $kt, value: $t) { + unsafe { [< __pb_rust_Map_ $kt _ $t _insert >](m, key, value) } + } + + fn get(m: RawMap, key: $kt) -> Option<$t> { + let mut val: $t = Default::default(); + let found = unsafe { [< __pb_rust_Map_ $kt _ $t _get >](m, key, &mut val) }; + if !found { + return None; + } + Some(val) + } + + fn remove(m: RawMap, key: $kt) -> Option<$t> { + let mut val: $t = Default::default(); + let removed = + unsafe { [< __pb_rust_Map_ $kt _ $t _remove >](m, key, &mut val) }; + if !removed { + return None; + } + Some(val) + } + } + )* } + } +} + +macro_rules! impl_scalar_maps { + ($($t:ty),*) => { + paste! { $( + pub trait [< MapWith $t:camel KeyOps >] { + fn new_map() -> RawMap; + fn clear(m: RawMap); + fn size(m: RawMap) -> usize; + fn insert(m: RawMap, key: $t, value: Self); + fn get(m: RawMap, key: $t) -> Option + where + Self: Sized; + fn remove(m: RawMap, key: $t) -> Option + where + Self: Sized; + } + + impl_scalar_map_values!( + $t, [< MapWith $t:camel KeyOps >] for i32, u32, f32, f64, bool, u64, i64 + ); + + impl<'msg, V: [< MapWith $t:camel KeyOps >]> Map<'msg, $t, V> { + pub fn new() -> Self { + let inner = MapInner { raw: V::new_map(), _phantom: PhantomData }; + Map { + inner, + _phantom_key: PhantomData, + _phantom_value: PhantomData + } + } + + pub fn size(&self) -> usize { + V::size(self.inner.raw) + } + + pub fn clear(&mut self) { + V::clear(self.inner.raw) + } + + pub fn get(&self, key: $t) -> Option { + V::get(self.inner.raw, key) + } + + pub fn remove(&mut self, key: $t) -> Option { + V::remove(self.inner.raw, key) + } + + pub fn insert(&mut self, key: $t, value: V) -> bool { + V::insert(self.inner.raw, key, value); + true + } + } + )* } + } +} + +impl_scalar_maps!(i32, u32, bool, u64, i64); + #[cfg(test)] mod tests { use super::*; @@ -362,4 +499,44 @@ mod tests { r.push(true); assert_that!(r.get(0), eq(Some(true))); } + + #[test] + fn i32_i32_map() { + let mut map = Map::<'_, i32, i32>::new(); + assert_that!(map.size(), eq(0)); + + assert_that!(map.insert(1, 2), eq(true)); + assert_that!(map.get(1), eq(Some(2))); + assert_that!(map.get(3), eq(None)); + assert_that!(map.size(), eq(1)); + + assert_that!(map.remove(1), eq(Some(2))); + assert_that!(map.size(), eq(0)); + assert_that!(map.remove(1), eq(None)); + + assert_that!(map.insert(4, 5), eq(true)); + assert_that!(map.insert(6, 7), eq(true)); + map.clear(); + assert_that!(map.size(), eq(0)); + } + + #[test] + fn i64_f64_map() { + let mut map = Map::<'_, i64, f64>::new(); + assert_that!(map.size(), eq(0)); + + assert_that!(map.insert(1, 2.5), eq(true)); + assert_that!(map.get(1), eq(Some(2.5))); + assert_that!(map.get(3), eq(None)); + assert_that!(map.size(), eq(1)); + + assert_that!(map.remove(1), eq(Some(2.5))); + assert_that!(map.size(), eq(0)); + assert_that!(map.remove(1), eq(None)); + + assert_that!(map.insert(4, 5.1), eq(true)); + assert_that!(map.insert(6, 7.2), eq(true)); + map.clear(); + assert_that!(map.size(), eq(0)); + } } diff --git a/rust/cpp_kernel/cpp_api.cc b/rust/cpp_kernel/cpp_api.cc index 1381611064..97f2c3cdd0 100644 --- a/rust/cpp_kernel/cpp_api.cc +++ b/rust/cpp_kernel/cpp_api.cc @@ -1,3 +1,4 @@ +#include "google/protobuf/map.h" #include "google/protobuf/repeated_field.h" extern "C" { @@ -36,4 +37,61 @@ expose_repeated_field_methods(uint64_t, u64); expose_repeated_field_methods(int64_t, i64); #undef expose_repeated_field_methods + +#define expose_scalar_map_methods(key_ty, rust_key_ty, value_ty, \ + rust_value_ty) \ + google::protobuf::Map* \ + __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_new() { \ + return new google::protobuf::Map(); \ + } \ + void __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_clear( \ + google::protobuf::Map* m) { \ + m->clear(); \ + } \ + size_t __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_size( \ + google::protobuf::Map* m) { \ + return m->size(); \ + } \ + void __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_insert( \ + google::protobuf::Map* m, key_ty key, value_ty val) { \ + (*m)[key] = val; \ + } \ + bool __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_get( \ + google::protobuf::Map* m, key_ty key, value_ty* value) { \ + auto it = m->find(key); \ + if (it == m->end()) { \ + return false; \ + } \ + *value = it->second; \ + return true; \ + } \ + bool __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_remove( \ + google::protobuf::Map* m, key_ty key, value_ty* value) { \ + auto it = m->find(key); \ + if (it == m->end()) { \ + return false; \ + } else { \ + *value = it->second; \ + m->erase(it); \ + return true; \ + } \ + } + +#define expose_scalar_map_methods_for_key_type(key_ty, rust_key_ty) \ + expose_scalar_map_methods(key_ty, rust_key_ty, int32_t, i32); \ + expose_scalar_map_methods(key_ty, rust_key_ty, uint32_t, u32); \ + expose_scalar_map_methods(key_ty, rust_key_ty, float, f32); \ + expose_scalar_map_methods(key_ty, rust_key_ty, double, f64); \ + expose_scalar_map_methods(key_ty, rust_key_ty, bool, bool); \ + expose_scalar_map_methods(key_ty, rust_key_ty, uint64_t, u64); \ + expose_scalar_map_methods(key_ty, rust_key_ty, int64_t, i64); + +expose_scalar_map_methods_for_key_type(int32_t, i32); +expose_scalar_map_methods_for_key_type(uint32_t, u32); +expose_scalar_map_methods_for_key_type(bool, bool); +expose_scalar_map_methods_for_key_type(uint64_t, u64); +expose_scalar_map_methods_for_key_type(int64_t, i64); + +#undef expose_scalar_map_methods +#undef expose_map_methods } diff --git a/rust/map.rs b/rust/map.rs index 450b854f31..ab54698ee3 100644 --- a/rust/map.rs +++ b/rust/map.rs @@ -7,8 +7,12 @@ use crate::{ __internal::Private, - __runtime::{Map, MapInner, MapValueType}, + __runtime::{ + Map, MapInner, MapWithBoolKeyOps, MapWithI32KeyOps, MapWithI64KeyOps, MapWithU32KeyOps, + MapWithU64KeyOps, + }, }; +use paste::paste; #[derive(Clone, Copy)] #[repr(transparent)] @@ -26,14 +30,6 @@ impl<'a, K: ?Sized, V: ?Sized> MapView<'a, K, V> { pub fn from_inner(_private: Private, inner: MapInner<'a>) -> Self { Self { inner: Map::<'a, K, V>::from_inner(_private, inner) } } - - pub fn len(&self) -> usize { - self.inner.len() - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } } impl<'a, K: ?Sized, V: ?Sized> MapMut<'a, K, V> { @@ -44,14 +40,22 @@ impl<'a, K: ?Sized, V: ?Sized> MapMut<'a, K, V> { macro_rules! impl_scalar_map_keys { ($(key_type $type:ty;)*) => { - $( - impl<'a, V: MapValueType> MapView<'a, $type, V> { + paste! { $( + impl<'a, V: [< MapWith $type:camel KeyOps >]> MapView<'a, $type, V> { pub fn get(&self, key: $type) -> Option { self.inner.get(key) } + + pub fn len(&self) -> usize { + self.inner.size() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } - impl<'a, V: MapValueType> MapMut<'a, $type, V> { + impl<'a, V: [< MapWith $type:camel KeyOps >]> MapMut<'a, $type, V> { pub fn insert(&mut self, key: $type, value: V) -> bool { self.inner.insert(key, value) } @@ -64,7 +68,7 @@ macro_rules! impl_scalar_map_keys { self.inner.clear() } } - )* + )* } }; } diff --git a/rust/shared.rs b/rust/shared.rs index 5aa054c0eb..26f0629da1 100644 --- a/rust/shared.rs +++ b/rust/shared.rs @@ -17,7 +17,6 @@ use std::fmt; /// These are the items protobuf users can access directly. #[doc(hidden)] pub mod __public { - #[cfg(upb_kernel)] pub use crate::map::{MapMut, MapView}; pub use crate::optional::{AbsentField, FieldEntry, Optional, PresentField}; pub use crate::primitive::{PrimitiveMut, SingularPrimitiveMut}; @@ -46,7 +45,6 @@ pub mod __runtime; pub mod __runtime; mod macros; -#[cfg(upb_kernel)] mod map; mod optional; mod primitive; diff --git a/rust/test/BUILD b/rust/test/BUILD index 291e0df234..03b16fb16e 100644 --- a/rust/test/BUILD +++ b/rust/test/BUILD @@ -313,6 +313,21 @@ rust_upb_proto_library( deps = [":nested_proto"], ) +cc_proto_library( + name = "map_unittest_cc_proto", + testonly = True, + deps = ["//src/google/protobuf:map_unittest_proto"], +) + +rust_cc_proto_library( + name = "map_unittest_cc_rust_proto", + testonly = True, + visibility = [ + "//rust/test/shared:__subpackages__", + ], + deps = [":map_unittest_cc_proto"], +) + rust_upb_proto_library( name = "map_unittest_upb_rust_proto", testonly = True, diff --git a/rust/test/shared/BUILD b/rust/test/shared/BUILD index 06f56338bd..04988de083 100644 --- a/rust/test/shared/BUILD +++ b/rust/test/shared/BUILD @@ -280,12 +280,32 @@ rust_test( ], ) +rust_test( + name = "accessors_map_cpp_test", + srcs = ["accessors_map_test.rs"], + proc_macro_deps = [ + "@crate_index//:paste", + ], + tags = [ + # TODO: Enable testing on arm once we support sanitizers for Rust on Arm. + "not_build:arm", + ], + deps = [ + "@crate_index//:googletest", + "//rust/test:map_unittest_cc_rust_proto", + ], +) + rust_test( name = "accessors_map_upb_test", srcs = ["accessors_map_test.rs"], proc_macro_deps = [ "@crate_index//:paste", ], + tags = [ + # TODO: Enable testing on arm once we support sanitizers for Rust on Arm. + "not_build:arm", + ], deps = [ "@crate_index//:googletest", "//rust/test:map_unittest_upb_rust_proto", diff --git a/rust/upb.rs b/rust/upb.rs index 136f4edbbc..8f713dad69 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -8,6 +8,7 @@ //! UPB FFI wrapper code for use by Rust Protobuf. use crate::__internal::{Private, PtrAndLen, RawArena, RawMap, RawMessage, RawRepeatedField}; +use paste::paste; use std::alloc; use std::alloc::Layout; use std::cell::UnsafeCell; @@ -376,7 +377,7 @@ extern "C" { } macro_rules! impl_repeated_primitives { - ($(($rs_type:ty, $union_field:ident, $upb_tag:expr)),*) => { + ($(($rs_type:ty, $ufield:ident, $upb_tag:expr)),*) => { $( impl<'msg> RepeatedField<'msg, $rs_type> { #[allow(dead_code)] @@ -392,7 +393,7 @@ macro_rules! impl_repeated_primitives { pub fn push(&mut self, val: $rs_type) { unsafe { upb_Array_Append( self.inner.raw, - upb_MessageValue { $union_field: val }, + upb_MessageValue { $ufield: val }, self.inner.arena.raw(), ) } } @@ -400,7 +401,7 @@ macro_rules! impl_repeated_primitives { if i >= self.len() { None } else { - unsafe { Some(upb_Array_Get(self.inner.raw, i).$union_field) } + unsafe { Some(upb_Array_Get(self.inner.raw, i).$ufield) } } } pub fn set(&self, i: usize, val: $rs_type) { @@ -410,7 +411,7 @@ macro_rules! impl_repeated_primitives { unsafe { upb_Array_Set( self.inner.raw, i, - upb_MessageValue { $union_field: val }, + upb_MessageValue { $ufield: val }, ) } } pub fn copy_from(&mut self, src: &RepeatedField<'_, $rs_type>) { @@ -511,137 +512,132 @@ impl<'msg, K: ?Sized, V: ?Sized> Clone for Map<'msg, K, V> { } impl<'msg, K: ?Sized, V: ?Sized> Map<'msg, K, V> { - pub fn len(&self) -> usize { - unsafe { upb_Map_Size(self.inner.raw) } - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - pub fn from_inner(_private: Private, inner: MapInner<'msg>) -> Self { Map { inner, _phantom_key: PhantomData, _phantom_value: PhantomData } } - - pub fn clear(&mut self) { - unsafe { upb_Map_Clear(self.inner.raw) } - } -} - -/// # Safety -/// Implementers of this trait must ensure that `pack_message_value` returns -/// a `upb_MessageValue` with the active variant indicated by `Self`. -pub unsafe trait MapType { - /// # Safety - /// The active variant of `outer` must be the `type PrimitiveValue` - unsafe fn unpack_message_value(_private: Private, outer: upb_MessageValue) -> Self; - - fn pack_message_value(_private: Private, inner: Self) -> upb_MessageValue; - - fn upb_ctype(_private: Private) -> UpbCType; - - fn zero_value(_private: Private) -> Self; } -/// Types implementing this trait can be used as map keys. -pub trait MapKeyType: MapType {} +macro_rules! impl_scalar_map_for_key_type { + ($key_t:ty, $key_ufield:ident, $key_upb_tag:expr, $trait:ident for $($t:ty, $ufield:ident, $upb_tag:expr, $zero_val:literal;)*) => { + paste! { $( + impl $trait for $t { + fn new_map(a: RawArena) -> RawMap { + unsafe { upb_Map_New(a, $key_upb_tag, $upb_tag) } + } -/// Types implementing this trait can be used as map values. -pub trait MapValueType: MapType {} + fn clear(m: RawMap) { + unsafe { upb_Map_Clear(m) } + } -macro_rules! impl_scalar_map_value_types { - ($($type:ty, $union_field:ident, $upb_tag:expr, $zero_val:literal;)*) => { - $( - unsafe impl MapType for $type { - unsafe fn unpack_message_value(_private: Private, outer: upb_MessageValue) -> Self { - unsafe { outer.$union_field } + fn size(m: RawMap) -> usize { + unsafe { upb_Map_Size(m) } } - fn pack_message_value(_private: Private, inner: Self) -> upb_MessageValue { - upb_MessageValue { $union_field: inner } + fn insert(m: RawMap, a: RawArena, key: $key_t, value: $t) -> bool { + unsafe { + upb_Map_Set( + m, + upb_MessageValue { $key_ufield: key }, + upb_MessageValue { $ufield: value}, + a + ) + } } - fn upb_ctype(_private: Private) -> UpbCType { - $upb_tag + fn get(m: RawMap, key: $key_t) -> Option<$t> { + let mut val = upb_MessageValue { $ufield: $zero_val }; + let found = unsafe { + upb_Map_Get(m, upb_MessageValue { $key_ufield: key }, &mut val) + }; + if !found { + return None; + } + Some(unsafe { val.$ufield }) } - fn zero_value(_private: Private) -> Self { - $zero_val + fn remove(m: RawMap, key: $key_t) -> Option<$t> { + let mut val = upb_MessageValue { $ufield: $zero_val }; + let removed = unsafe { + upb_Map_Delete(m, upb_MessageValue { $key_ufield: key }, &mut val) + }; + if !removed { + return None; + } + Some(unsafe { val.$ufield }) } } + )* } + } +} + +macro_rules! impl_scalar_map_for_key_types { + ($($t:ty, $ufield:ident, $upb_tag:expr;)*) => { + paste! { $( + pub trait [< MapWith $t:camel KeyOps >] { + fn new_map(a: RawArena) -> RawMap; + fn clear(m: RawMap); + fn size(m: RawMap) -> usize; + fn insert(m: RawMap, a: RawArena, key: $t, value: Self) -> bool; + fn get(m: RawMap, key: $t) -> Option + where + Self: Sized; + fn remove(m: RawMap, key: $t) -> Option + where + Self: Sized; + } - impl MapValueType for $type {} - )* - }; -} - -impl_scalar_map_value_types!( - f32, float_val, UpbCType::Float, 0f32; - f64, double_val, UpbCType::Double, 0f64; - i32, int32_val, UpbCType::Int32, 0i32; - u32, uint32_val, UpbCType::UInt32, 0u32; - i64, int64_val, UpbCType::Int64, 0i64; - u64, uint64_val, UpbCType::UInt64, 0u64; - bool, bool_val, UpbCType::Bool, false; -); - -macro_rules! impl_scalar_map_key_types { - ($($type:ty;)*) => { - $( - impl MapKeyType for $type {} - )* - }; -} + impl_scalar_map_for_key_type!($t, $ufield, $upb_tag, [< MapWith $t:camel KeyOps >] for + f32, float_val, UpbCType::Float, 0f32; + f64, double_val, UpbCType::Double, 0f64; + i32, int32_val, UpbCType::Int32, 0i32; + u32, uint32_val, UpbCType::UInt32, 0u32; + i64, int64_val, UpbCType::Int64, 0i64; + u64, uint64_val, UpbCType::UInt64, 0u64; + bool, bool_val, UpbCType::Bool, false; + ); + + impl<'msg, V: [< MapWith $t:camel KeyOps >]> Map<'msg, $t, V> { + pub fn new(arena: &'msg mut Arena) -> Self { + let inner = MapInner { raw: V::new_map(arena.raw()), arena }; + Map { + inner, + _phantom_key: PhantomData, + _phantom_value: PhantomData + } + } -impl_scalar_map_key_types!( - i32; u32; i64; u64; bool; -); + pub fn size(&self) -> usize { + V::size(self.inner.raw) + } -impl<'msg, K: MapKeyType, V: MapValueType> Map<'msg, K, V> { - pub fn new(arena: &'msg Arena) -> Self { - unsafe { - let raw_map = upb_Map_New(arena.raw(), K::upb_ctype(Private), V::upb_ctype(Private)); - Map { - inner: MapInner { raw: raw_map, arena }, - _phantom_key: PhantomData, - _phantom_value: PhantomData, - } - } - } + pub fn clear(&mut self) { + V::clear(self.inner.raw) + } - pub fn get(&self, key: K) -> Option { - let mut val = V::pack_message_value(Private, V::zero_value(Private)); - let found = - unsafe { upb_Map_Get(self.inner.raw, K::pack_message_value(Private, key), &mut val) }; - if !found { - return None; - } - Some(unsafe { V::unpack_message_value(Private, val) }) - } + pub fn get(&self, key: $t) -> Option { + V::get(self.inner.raw, key) + } - pub fn insert(&mut self, key: K, value: V) -> bool { - unsafe { - upb_Map_Set( - self.inner.raw, - K::pack_message_value(Private, key), - V::pack_message_value(Private, value), - self.inner.arena.raw(), - ) - } - } + pub fn remove(&mut self, key: $t) -> Option { + V::remove(self.inner.raw, key) + } - pub fn remove(&mut self, key: K) -> Option { - let mut val = V::pack_message_value(Private, V::zero_value(Private)); - let removed = unsafe { - upb_Map_Delete(self.inner.raw, K::pack_message_value(Private, key), &mut val) - }; - if !removed { - return None; - } - Some(unsafe { V::unpack_message_value(Private, val) }) + pub fn insert(&mut self, key: $t, value: V) -> bool { + V::insert(self.inner.raw, self.inner.arena.raw(), key, value) + } + } + )* } } } +impl_scalar_map_for_key_types!( + i32, int32_val, UpbCType::Int32; + u32, uint32_val, UpbCType::UInt32; + i64, int64_val, UpbCType::Int64; + u64, uint64_val, UpbCType::UInt64; + bool, bool_val, UpbCType::Bool; +); + extern "C" { fn upb_Map_New(arena: RawArena, key_type: UpbCType, value_type: UpbCType) -> RawMap; fn upb_Map_Size(map: RawMap) -> usize; @@ -720,43 +716,43 @@ mod tests { #[test] fn i32_i32_map() { - let arena = Arena::new(); - let mut map = Map::<'_, i32, i32>::new(&arena); - assert_that!(map.len(), eq(0)); + let mut arena = Arena::new(); + let mut map = Map::<'_, i32, i32>::new(&mut arena); + assert_that!(map.size(), eq(0)); assert_that!(map.insert(1, 2), eq(true)); assert_that!(map.get(1), eq(Some(2))); assert_that!(map.get(3), eq(None)); - assert_that!(map.len(), eq(1)); + assert_that!(map.size(), eq(1)); assert_that!(map.remove(1), eq(Some(2))); - assert_that!(map.len(), eq(0)); + assert_that!(map.size(), eq(0)); assert_that!(map.remove(1), eq(None)); assert_that!(map.insert(4, 5), eq(true)); assert_that!(map.insert(6, 7), eq(true)); map.clear(); - assert_that!(map.len(), eq(0)); + assert_that!(map.size(), eq(0)); } #[test] fn i64_f64_map() { - let arena = Arena::new(); - let mut map = Map::<'_, i64, f64>::new(&arena); - assert_that!(map.len(), eq(0)); + let mut arena = Arena::new(); + let mut map = Map::<'_, i64, f64>::new(&mut arena); + assert_that!(map.size(), eq(0)); assert_that!(map.insert(1, 2.5), eq(true)); assert_that!(map.get(1), eq(Some(2.5))); assert_that!(map.get(3), eq(None)); - assert_that!(map.len(), eq(1)); + assert_that!(map.size(), eq(1)); assert_that!(map.remove(1), eq(Some(2.5))); - assert_that!(map.len(), eq(0)); + assert_that!(map.size(), eq(0)); assert_that!(map.remove(1), eq(None)); assert_that!(map.insert(4, 5.1), eq(true)); assert_that!(map.insert(6, 7.2), eq(true)); map.clear(); - assert_that!(map.len(), eq(0)); + assert_that!(map.size(), eq(0)); } } diff --git a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h index dcefbe8ae4..442de3c7cc 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h +++ b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h @@ -111,6 +111,7 @@ class Map final : public AccessorGenerator { ~Map() override = default; void InMsgImpl(Context field) const override; void InExternC(Context field) const override; + void InThunkCc(Context field) const override; }; } // namespace rust diff --git a/src/google/protobuf/compiler/rust/accessors/map.cc b/src/google/protobuf/compiler/rust/accessors/map.cc index f7eda845c4..1df6b23c42 100644 --- a/src/google/protobuf/compiler/rust/accessors/map.cc +++ b/src/google/protobuf/compiler/rust/accessors/map.cc @@ -5,6 +5,7 @@ // license that can be found in the LICENSE file or at // https://developers.google.com/open-source/licenses/bsd +#include "google/protobuf/compiler/cpp/helpers.h" #include "google/protobuf/compiler/rust/accessors/accessor_generator.h" #include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/compiler/rust/naming.h" @@ -24,33 +25,55 @@ void Map::InMsgImpl(Context field) const { {"Key", PrimitiveRsTypeName(key_type)}, {"Value", PrimitiveRsTypeName(value_type)}, {"getter_thunk", Thunk(field, "get")}, - {"getter_thunk_mut", Thunk(field, "get_mut")}, + {"getter_mut_thunk", Thunk(field, "get_mut")}, {"getter", [&] { if (field.is_upb()) { field.Emit({}, R"rs( - pub fn r#$field$(&self) -> $pb$::MapView<'_, $Key$, $Value$> { - let inner = unsafe { - $getter_thunk$(self.inner.msg) - }.map_or_else(|| unsafe {$pbr$::empty_map()}, |raw| { - $pbr$::MapInner{ raw, arena: &self.inner.arena } - }); - $pb$::MapView::from_inner($pbi$::Private, inner) - } - - pub fn r#$field$_mut(&mut self) - -> $pb$::MapMut<'_, $Key$, $Value$> { - let raw = unsafe { - $getter_thunk_mut$(self.inner.msg, self.inner.arena.raw()) - }; - let inner = $pbr$::MapInner{ raw, arena: &self.inner.arena }; - $pb$::MapMut::from_inner($pbi$::Private, inner) - } - )rs"); + pub fn r#$field$(&self) -> $pb$::MapView<'_, $Key$, $Value$> { + let inner = unsafe { + $getter_thunk$(self.inner.msg) + }.map_or_else(|| unsafe {$pbr$::empty_map()}, |raw| { + $pbr$::MapInner{ raw, arena: &self.inner.arena } + }); + $pb$::MapView::from_inner($pbi$::Private, inner) + })rs"); + } else { + field.Emit({}, R"rs( + pub fn r#$field$(&self) -> $pb$::MapView<'_, $Key$, $Value$> { + let inner = $pbr$::MapInner { + raw: unsafe { $getter_thunk$(self.inner.msg) }, + _phantom: std::marker::PhantomData + }; + $pb$::MapView::from_inner($pbi$::Private, inner) + })rs"); + } + }}, + {"getter_mut", + [&] { + if (field.is_upb()) { + field.Emit({}, R"rs( + pub fn r#$field$_mut(&mut self) -> $pb$::MapMut<'_, $Key$, $Value$> { + let raw = unsafe { + $getter_mut_thunk$(self.inner.msg, self.inner.arena.raw()) + }; + let inner = $pbr$::MapInner{ raw, arena: &self.inner.arena }; + $pb$::MapMut::from_inner($pbi$::Private, inner) + })rs"); + } else { + field.Emit({}, R"rs( + pub fn r#$field$_mut(&mut self) -> $pb$::MapMut<'_, $Key$, $Value$> { + let inner = $pbr$::MapInner { + raw: unsafe { $getter_mut_thunk$(self.inner.msg) }, + _phantom: std::marker::PhantomData + }; + $pb$::MapMut::from_inner($pbi$::Private, inner) + })rs"); } }}}, R"rs( $getter$ + $getter_mut$ )rs"); } @@ -58,16 +81,21 @@ void Map::InExternC(Context field) const { field.Emit( { {"getter_thunk", Thunk(field, "get")}, - {"getter_thunk_mut", Thunk(field, "get_mut")}, + {"getter_mut_thunk", Thunk(field, "get_mut")}, {"getter", [&] { if (field.is_upb()) { field.Emit({}, R"rs( - fn $getter_thunk$(raw_msg: $pbi$::RawMessage) - -> Option<$pbi$::RawMap>; - fn $getter_thunk_mut$(raw_msg: $pbi$::RawMessage, - arena: $pbi$::RawArena) -> $pbi$::RawMap; - )rs"); + fn $getter_thunk$(raw_msg: $pbi$::RawMessage) + -> Option<$pbi$::RawMap>; + fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage, + arena: $pbi$::RawArena) -> $pbi$::RawMap; + )rs"); + } else { + field.Emit({}, R"rs( + fn $getter_thunk$(msg: $pbi$::RawMessage) -> $pbi$::RawMap; + fn $getter_mut_thunk$(msg: $pbi$::RawMessage,) -> $pbi$::RawMap; + )rs"); } }}, }, @@ -76,6 +104,30 @@ void Map::InExternC(Context field) const { )rs"); } +void Map::InThunkCc(Context field) const { + field.Emit( + {{"field", cpp::FieldName(&field.desc())}, + {"Key", cpp::PrimitiveTypeName( + field.desc().message_type()->map_key()->cpp_type())}, + {"Value", cpp::PrimitiveTypeName( + field.desc().message_type()->map_value()->cpp_type())}, + {"QualifiedMsg", + cpp::QualifiedClassName(field.desc().containing_type())}, + {"getter_thunk", Thunk(field, "get")}, + {"getter_mut_thunk", Thunk(field, "get_mut")}, + {"impls", + [&] { + field.Emit( + R"cc( + const void* $getter_thunk$($QualifiedMsg$& msg) { + return &msg.$field$(); + } + void* $getter_mut_thunk$($QualifiedMsg$* msg) { return msg->mutable_$field$(); } + )cc"); + }}}, + "$impls$"); +} + } // namespace rust } // namespace compiler } // namespace protobuf