From 1142838bbc88727c9f65f2227c748cf976537f06 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Thu, 12 Sep 2024 11:22:07 -0700 Subject: [PATCH] Rust: get all types onto a single blanket ProxiedInMapValue implementation for upb I realized that as long as we implement `UpbTypeConversions` for enums, we can easily get the blanket implementation for messages to work for enums as well. Luckily the blanket implementation also happens to work for non-generated types, so this gets us down to just one ProxiedInMapValue implementation for upb. PiperOrigin-RevId: 673927343 --- rust/upb.rs | 104 +-------------------- src/google/protobuf/compiler/rust/enum.cc | 108 ++++------------------ 2 files changed, 23 insertions(+), 189 deletions(-) diff --git a/rust/upb.rs b/rust/upb.rs index 5cd4bb4704..ddcd3acd17 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -9,8 +9,8 @@ use crate::__internal::{Enum, Private, SealedInternal}; use crate::{ - IntoProxied, Map, MapIter, MapMut, MapView, Message, Mut, ProtoBytes, ProtoStr, ProtoString, - Proxied, ProxiedInMapValue, ProxiedInRepeated, Repeated, RepeatedMut, RepeatedView, View, + IntoProxied, Map, MapIter, MapMut, MapView, Mut, ProtoBytes, ProtoStr, ProtoString, Proxied, + ProxiedInMapValue, ProxiedInRepeated, Repeated, RepeatedMut, RepeatedView, View, }; use core::fmt::Debug; use std::mem::{size_of, ManuallyDrop, MaybeUninit}; @@ -683,108 +683,10 @@ impl RawMapIter { } } -macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { - ($key_t:ty ; $($t:ty),*) => { - $( - impl ProxiedInMapValue<$key_t> for $t { - fn map_new(_private: Private) -> Map<$key_t, Self> { - let arena = Arena::new(); - let raw = unsafe { - upb_Map_New(arena.raw(), - <$key_t as UpbTypeConversions>::upb_type(), - <$t as UpbTypeConversions>::upb_type()) - }; - Map::from_inner(Private, InnerMap { raw, arena }) - } - - unsafe fn map_free(_private: Private, _map: &mut Map<$key_t, Self>) { - // No-op: the memory will be dropped by the arena. - } - - fn map_clear(mut map: MapMut<$key_t, Self>) { - unsafe { - upb_Map_Clear(map.as_raw(Private)); - } - } - - fn map_len(map: MapView<$key_t, Self>) -> usize { - unsafe { - upb_Map_Size(map.as_raw(Private)) - } - } - - fn map_insert(mut map: MapMut<$key_t, Self>, key: View<'_, $key_t>, value: impl IntoProxied) -> bool { - let arena = map.raw_arena(Private); - unsafe { - upb_Map_InsertAndReturnIfInserted( - map.as_raw(Private), - <$key_t as UpbTypeConversions>::to_message_value(key), - <$t as UpbTypeConversions>::into_message_value_fuse_if_required(arena, value.into_proxied(Private)), - arena - ) - } - } - - fn map_get<'a>(map: MapView<'a, $key_t, Self>, key: View<'_, $key_t>) -> Option> { - let mut val = MaybeUninit::uninit(); - let found = unsafe { - upb_Map_Get(map.as_raw(Private), <$key_t as UpbTypeConversions>::to_message_value(key), - val.as_mut_ptr()) - }; - if !found { - return None; - } - Some(unsafe { <$t as UpbTypeConversions>::from_message_value(val.assume_init()) }) - } - - fn map_remove(mut map: MapMut<$key_t, Self>, key: View<'_, $key_t>) -> bool { - unsafe { - upb_Map_Delete(map.as_raw(Private), - <$key_t as UpbTypeConversions>::to_message_value(key), - ptr::null_mut()) - } - } - - fn map_iter(map: MapView<$key_t, Self>) -> MapIter<$key_t, Self> { - // SAFETY: View> guarantees its RawMap outlives '_. - unsafe { - MapIter::from_raw(Private, RawMapIter::new(map.as_raw(Private))) - } - } - - fn map_iter_next<'a>( - iter: &mut MapIter<'a, $key_t, Self> - ) -> Option<(View<'a, $key_t>, View<'a, Self>)> { - // SAFETY: MapIter<'a, ..> guarantees its RawMapIter outlives 'a. - unsafe { iter.as_raw_mut(Private).next_unchecked() } - // SAFETY: MapIter returns key and values message values - // with the variants for K and V active. - .map(|(k, v)| unsafe {( - <$key_t as UpbTypeConversions>::from_message_value(k), - <$t as UpbTypeConversions>::from_message_value(v), - )}) - } - } - )* - } -} - -macro_rules! impl_ProxiedInMapValue_for_key_types { - ($($t:ty),*) => { - $( - impl_ProxiedInMapValue_for_non_generated_value_types!( - $t ; f32, f64, i32, u32, i64, u64, bool, ProtoString, ProtoBytes - ); - )* - } -} - -impl_ProxiedInMapValue_for_key_types!(i32, u32, i64, u64, bool, ProtoString); - impl ProxiedInMapValue for MessageType where Key: Proxied + UpbTypeConversions, - MessageType: Proxied + UpbTypeConversions + Message, + MessageType: Proxied + UpbTypeConversions, { fn map_new(_private: Private) -> Map { let arena = Arena::new(); diff --git a/src/google/protobuf/compiler/rust/enum.cc b/src/google/protobuf/compiler/rust/enum.cc index dd9df5708f..e02304e303 100644 --- a/src/google/protobuf/compiler/rust/enum.cc +++ b/src/google/protobuf/compiler/rust/enum.cc @@ -148,97 +148,29 @@ void EnumProxiedInMapValue(Context& ctx, const EnumDescriptor& desc) { } return; case Kernel::kUpb: - for (const auto& t : kMapKeyTypes) { - ctx.Emit({io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); }) - .WithSuffix("")}, - R"rs( - impl $pb$::ProxiedInMapValue<$key_t$> for $name$ { - fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> { - let arena = $pbr$::Arena::new(); - let raw = unsafe { - $pbr$::upb_Map_New( - arena.raw(), - <$key_t$ as $pbr$::UpbTypeConversions>::upb_type(), - $pbr$::CType::Enum) - }; - $pb$::Map::from_inner( - $pbi$::Private, - $pbr$::InnerMap::new(raw, arena)) - } - - unsafe fn map_free(_private: $pbi$::Private, _map: &mut $pb$::Map<$key_t$, Self>) { - // No-op: the memory will be dropped by the arena. - } - - fn map_clear(mut map: $pb$::MapMut<$key_t$, Self>) { - unsafe { - $pbr$::upb_Map_Clear(map.as_raw($pbi$::Private)); - } - } - - fn map_len(map: $pb$::MapView<$key_t$, Self>) -> usize { - unsafe { - $pbr$::upb_Map_Size(map.as_raw($pbi$::Private)) - } - } - - fn map_insert(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>, value: impl $pb$::IntoProxied) -> bool { - let arena = map.inner($pbi$::Private).raw_arena(); - unsafe { - $pbr$::upb_Map_InsertAndReturnIfInserted( - map.as_raw($pbi$::Private), - <$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key), - $pbr$::upb_MessageValue { int32_val: value.into_proxied($pbi$::Private).0 }, - arena - ) - } - } + ctx.Emit(R"rs( + impl $pbr$::UpbTypeConversions for $name$ { + fn upb_type() -> $pbr$::CType { + $pbr$::CType::Enum + } - fn map_get<'a>(map: $pb$::MapView<'a, $key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> Option<$pb$::View<'a, Self>> { - let mut val = $std$::mem::MaybeUninit::uninit(); - let found = unsafe { - $pbr$::upb_Map_Get( - map.as_raw($pbi$::Private), - <$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key), - val.as_mut_ptr()) - }; - if !found { - return None; - } - Some($name$(unsafe { val.assume_init().int32_val })) - } + fn to_message_value( + val: $pb$::View<'_, Self>) -> $pbr$::upb_MessageValue { + $pbr$::upb_MessageValue { int32_val: val.0 } + } - fn map_remove(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> bool { - let mut val = $std$::mem::MaybeUninit::uninit(); - unsafe { - $pbr$::upb_Map_Delete( - map.as_raw($pbi$::Private), - <$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key), - val.as_mut_ptr()) - } - } - fn map_iter(map: $pb$::MapView<$key_t$, Self>) -> $pb$::MapIter<$key_t$, Self> { - // SAFETY: View> guarantees its RawMap outlives '_. - unsafe { - $pb$::MapIter::from_raw($pbi$::Private, $pbr$::RawMapIter::new(map.as_raw($pbi$::Private))) - } - } + unsafe fn into_message_value_fuse_if_required( + raw_parent_arena: $pbr$::RawArena, + val: Self) -> $pbr$::upb_MessageValue { + $pbr$::upb_MessageValue { int32_val: val.0 } + } - fn map_iter_next<'a>( - iter: &mut $pb$::MapIter<'a, $key_t$, Self> - ) -> Option<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> { - // SAFETY: MapIter<'a, ..> guarantees its RawMapIter outlives 'a. - unsafe { iter.as_raw_mut($pbi$::Private).next_unchecked() } - // SAFETY: MapIter returns key and values message values - // with the variants for K and V active. - .map(|(k, v)| unsafe {( - <$key_t$ as $pbr$::UpbTypeConversions>::from_message_value(k), - Self(v.int32_val), - )}) - } - } - )rs"); - } + unsafe fn from_message_value<'msg>(val: $pbr$::upb_MessageValue) + -> $pb$::View<'msg, Self> { + $name$(unsafe { val.int32_val }) + } + } + )rs"); return; } }