Rust C++: get all map fields onto a common implementation of ProxiedInMapValue

This CL migrates messages, enums, and primitive types all onto the same blanket
implementation of the `ProxiedInMapValue` trait. This gets us to the point
where messages and enums no longer need to generate any significant amount of
extra code just in case they might be used as a map value.

There are a few big pieces to this:
 - I generalized the message-specific FFI endpoints in `rust/cpp_kernel/map.cc`
   to be able to additionally handle enums and primitive types as values. This
   mostly consisted of replacing `MessageLite*` parameters with a new `MapValue`
   tagged union.
 - On the Rust side, I added a new blanket implementation of
   `ProxiedInMapValue` in rust/cpp.rs. It relies on its value type to implement
   a new `CppMapTypeConversions` trait so that it can convert to and from the
   `MapValue` tagged union used for FFI.
 - In the Rust generated code, I deleted the generated `ProxiedInMapValue`
   implementations for messages and enums and replaced them with
   implementations of the `CppMapTypeConversions` trait.

PiperOrigin-RevId: 687355817
pull/18828/head
Adam Cozzette 3 months ago committed by Copybara-Service
parent 2ff033011f
commit cbb3edd86d
  1. 548
      rust/cpp.rs
  2. 1
      rust/cpp_kernel/BUILD
  3. 319
      rust/cpp_kernel/map.cc
  4. 122
      rust/cpp_kernel/map.h
  5. 116
      src/google/protobuf/compiler/rust/enum.cc
  6. 1
      src/google/protobuf/compiler/rust/generator.cc
  7. 133
      src/google/protobuf/compiler/rust/message.cc
  8. 38
      src/google/protobuf/compiler/rust/naming.cc
  9. 1
      src/google/protobuf/map.h

@ -762,48 +762,397 @@ impl UntypedMapIterator {
}
}
// This enum is used to pass some information about the key type of a map to
// C++. The main purpose is to indicate the size of the key so that we can
// determine the correct size and offset information of map entries on the C++
// side. We also rely on it to indicate whether the key is a string or not.
// LINT.IfChange(map_ffi)
#[doc(hidden)]
#[repr(u8)]
pub enum MapKeyCategory {
OneByte,
FourBytes,
EightBytes,
StdString,
#[derive(Debug, PartialEq)]
pub enum MapValueTag {
Bool,
U32,
U64,
String,
Message,
}
// For the purposes of FFI, we treat all numeric types of a given size the same
// way. For example, u32, i32, and f32 values are all represented as a u32.
// Likewise, u64, i64, and f64 values are all stored in a u64.
#[doc(hidden)]
pub trait MapKey {
const CATEGORY: MapKeyCategory;
#[repr(C)]
pub union MapValueUnion {
pub b: bool,
pub u: u32,
pub uu: u64,
// Generally speaking, if s is set then it should not be None. However, we
// do set it to None in the special case where the MapValue is just a
// "prototype" (see below). In that scenario, we just want to indicate the
// value type without having to allocate a real C++ std::string.
pub s: Option<CppStdString>,
pub m: RawMessage,
}
// We use this tagged union to represent map values for the purposes of FFI.
#[doc(hidden)]
#[repr(C)]
pub struct MapValue {
pub tag: MapValueTag,
pub val: MapValueUnion,
}
// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp_kernel/map.cc:
// map_ffi)
impl MapValue {
fn make_bool(b: bool) -> Self {
MapValue { tag: MapValueTag::Bool, val: MapValueUnion { b } }
}
pub fn make_u32(u: u32) -> Self {
MapValue { tag: MapValueTag::U32, val: MapValueUnion { u } }
}
fn make_u64(uu: u64) -> Self {
MapValue { tag: MapValueTag::U64, val: MapValueUnion { uu } }
}
fn make_string(s: CppStdString) -> Self {
MapValue { tag: MapValueTag::String, val: MapValueUnion { s: Some(s) } }
}
pub fn make_message(m: RawMessage) -> Self {
MapValue { tag: MapValueTag::Message, val: MapValueUnion { m } }
}
}
pub trait CppMapTypeConversions: Proxied {
// We have a notion of a map value "prototype", which is a MapValue that
// contains just enough information to indicate the value type of the map.
// We need this on the C++ side to be able to determine size and offset
// information about the map entry. For messages, the prototype is
// the message default instance. For all other types, it is just a MapValue
// with the appropriate tag.
fn get_prototype() -> MapValue;
fn to_map_value(self) -> MapValue;
/// # Safety
/// - `value` must store the correct type for `Self`. If it is a string or
/// bytes, then it must not be None. If `Self` is a closed enum, then
/// `value` must store a valid value for that enum. If `Self` is a
/// message, then `value` must store a message of the same type.
/// - The value must be valid for `'a` lifetime.
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self>;
}
impl CppMapTypeConversions for u32 {
fn get_prototype() -> MapValue {
MapValue::make_u32(0)
}
fn to_map_value(self) -> MapValue {
MapValue::make_u32(self)
}
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
debug_assert_eq!(value.tag, MapValueTag::U32);
unsafe { value.val.u }
}
}
impl CppMapTypeConversions for i32 {
fn get_prototype() -> MapValue {
MapValue::make_u32(0)
}
fn to_map_value(self) -> MapValue {
MapValue::make_u32(self as u32)
}
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
debug_assert_eq!(value.tag, MapValueTag::U32);
unsafe { value.val.u as i32 }
}
}
impl CppMapTypeConversions for u64 {
fn get_prototype() -> MapValue {
MapValue::make_u64(0)
}
fn to_map_value(self) -> MapValue {
MapValue::make_u64(self)
}
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
debug_assert_eq!(value.tag, MapValueTag::U64);
unsafe { value.val.uu }
}
}
impl CppMapTypeConversions for i64 {
fn get_prototype() -> MapValue {
MapValue::make_u64(0)
}
fn to_map_value(self) -> MapValue {
MapValue::make_u64(self as u64)
}
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
debug_assert_eq!(value.tag, MapValueTag::U64);
unsafe { value.val.uu as i64 }
}
}
impl CppMapTypeConversions for f32 {
fn get_prototype() -> MapValue {
MapValue::make_u32(0)
}
fn to_map_value(self) -> MapValue {
MapValue::make_u32(self.to_bits())
}
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
debug_assert_eq!(value.tag, MapValueTag::U32);
unsafe { Self::from_bits(value.val.u) }
}
}
impl CppMapTypeConversions for f64 {
fn get_prototype() -> MapValue {
MapValue::make_u64(0)
}
fn to_map_value(self) -> MapValue {
MapValue::make_u64(self.to_bits())
}
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
debug_assert_eq!(value.tag, MapValueTag::U64);
unsafe { Self::from_bits(value.val.uu) }
}
}
impl CppMapTypeConversions for bool {
fn get_prototype() -> MapValue {
MapValue::make_bool(false)
}
fn to_map_value(self) -> MapValue {
MapValue::make_bool(self)
}
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
debug_assert_eq!(value.tag, MapValueTag::Bool);
unsafe { value.val.b }
}
}
impl CppMapTypeConversions for ProtoString {
fn get_prototype() -> MapValue {
MapValue { tag: MapValueTag::String, val: MapValueUnion { s: None } }
}
fn to_map_value(self) -> MapValue {
MapValue::make_string(protostr_into_cppstdstring(self))
}
unsafe fn from_map_value<'a>(value: MapValue) -> &'a ProtoStr {
debug_assert_eq!(value.tag, MapValueTag::String);
unsafe {
ProtoStr::from_utf8_unchecked(
ptrlen_to_str(proto2_rust_cpp_string_to_view(value.val.s.unwrap())).into(),
)
}
}
}
impl CppMapTypeConversions for ProtoBytes {
fn get_prototype() -> MapValue {
MapValue { tag: MapValueTag::String, val: MapValueUnion { s: None } }
}
fn to_map_value(self) -> MapValue {
MapValue::make_string(protobytes_into_cppstdstring(self))
}
unsafe fn from_map_value<'a>(value: MapValue) -> &'a [u8] {
debug_assert_eq!(value.tag, MapValueTag::String);
unsafe { proto2_rust_cpp_string_to_view(value.val.s.unwrap()).as_ref() }
}
}
// This trait encapsulates functionality that is specific to each map key type.
// We need this primarily so that we can call the appropriate FFI function for
// the key type.
#[doc(hidden)]
pub trait MapKey
where
Self: Proxied,
{
type FfiKey;
fn to_view<'a>(key: Self::FfiKey) -> View<'a, Self>;
unsafe fn free(m: RawMap, prototype: MapValue);
unsafe fn clear(m: RawMap, prototype: MapValue);
unsafe fn insert(m: RawMap, key: View<'_, Self>, value: MapValue) -> bool;
unsafe fn get(
m: RawMap,
prototype: MapValue,
key: View<'_, Self>,
value: *mut MapValue,
) -> bool;
unsafe fn iter_get(
iter: &mut UntypedMapIterator,
prototype: MapValue,
key: *mut Self::FfiKey,
value: *mut MapValue,
);
unsafe fn remove(m: RawMap, prototype: MapValue, key: View<'_, Self>) -> bool;
}
macro_rules! generate_map_key_impl {
( $($key:ty, $category:expr;)* ) => {
( $($key:ty, $mutable_ffi_key:ty, $to_ffi:expr, $from_ffi:expr;)* ) => {
paste! {
$(
impl MapKey for $key {
const CATEGORY: MapKeyCategory = $category;
type FfiKey = $mutable_ffi_key;
#[inline]
fn to_view<'a>(key: Self::FfiKey) -> View<'a, Self> {
$from_ffi(key)
}
#[inline]
unsafe fn free(m: RawMap, prototype: MapValue) {
unsafe { [< proto2_rust_map_free_ $key >](m, prototype) }
}
#[inline]
unsafe fn clear(m: RawMap, prototype: MapValue) {
unsafe { [< proto2_rust_map_clear_ $key >](m, prototype) }
}
#[inline]
unsafe fn insert(
m: RawMap,
key: View<'_, Self>,
value: MapValue,
) -> bool {
unsafe { [< proto2_rust_map_insert_ $key >](m, $to_ffi(key), value) }
}
#[inline]
unsafe fn get(
m: RawMap,
prototype: MapValue,
key: View<'_, Self>,
value: *mut MapValue,
) -> bool {
unsafe { [< proto2_rust_map_get_ $key >](m, prototype, $to_ffi(key), value) }
}
#[inline]
unsafe fn iter_get(
iter: &mut UntypedMapIterator,
prototype: MapValue,
key: *mut Self::FfiKey,
value: *mut MapValue,
) {
unsafe { [< proto2_rust_map_iter_get_ $key >](iter, prototype, key, value) }
}
#[inline]
unsafe fn remove(m: RawMap, prototype: MapValue, key: View<'_, Self>) -> bool {
unsafe { [< proto2_rust_map_remove_ $key >](m, prototype, $to_ffi(key)) }
}
}
)*
}
}
}
// LINT.IfChange(map_key_category)
generate_map_key_impl!(
bool, MapKeyCategory::OneByte;
i32, MapKeyCategory::FourBytes;
u32, MapKeyCategory::FourBytes;
i64, MapKeyCategory::EightBytes;
u64, MapKeyCategory::EightBytes;
ProtoString, MapKeyCategory::StdString;
bool, bool, identity, identity;
i32, i32, identity, identity;
u32, u32, identity, identity;
i64, i64, identity, identity;
u64, u64, identity, identity;
ProtoString, PtrAndLen, str_to_ptrlen, ptrlen_to_str;
);
// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp_kernel/map.cc:
// map_key_category)
impl<Key, Value> ProxiedInMapValue<Key> for Value
where
Key: Proxied + MapKey,
Value: Proxied + CppMapTypeConversions,
{
fn map_new(_private: Private) -> Map<Key, Self> {
unsafe { Map::from_inner(Private, InnerMap::new(proto2_rust_map_new())) }
}
unsafe fn map_free(_private: Private, map: &mut Map<Key, Self>) {
unsafe {
Key::free(map.as_raw(Private), Self::get_prototype());
}
}
fn map_clear(mut map: MapMut<Key, Self>) {
unsafe {
Key::clear(map.as_raw(Private), Self::get_prototype());
}
}
fn map_len(map: MapView<Key, Self>) -> usize {
unsafe { proto2_rust_map_size(map.as_raw(Private)) }
}
fn map_insert(
mut map: MapMut<Key, Self>,
key: View<'_, Key>,
value: impl IntoProxied<Self>,
) -> bool {
unsafe { Key::insert(map.as_raw(Private), key, value.into_proxied(Private).to_map_value()) }
}
fn map_get<'a>(map: MapView<'a, Key, Self>, key: View<'_, Key>) -> Option<View<'a, Self>> {
let mut value = std::mem::MaybeUninit::uninit();
let found = unsafe {
Key::get(map.as_raw(Private), Self::get_prototype(), key, value.as_mut_ptr())
};
if !found {
return None;
}
unsafe { Some(Self::from_map_value(value.assume_init())) }
}
fn map_remove(mut map: MapMut<Key, Self>, key: View<'_, Key>) -> bool {
unsafe { Key::remove(map.as_raw(Private), Self::get_prototype(), key) }
}
fn map_iter(map: MapView<Key, Self>) -> MapIter<Key, Self> {
// SAFETY:
// - The backing map for `map.as_raw` is valid for at least '_.
// - A View that is live for '_ guarantees the backing map is unmodified for '_.
// - The `iter` function produces an iterator that is valid for the key and
// value types, and live for at least '_.
unsafe { MapIter::from_raw(Private, proto2_rust_map_iter(map.as_raw(Private))) }
}
fn map_iter_next<'a>(
iter: &mut MapIter<'a, Key, Self>,
) -> Option<(View<'a, Key>, View<'a, Self>)> {
// SAFETY:
// - The `MapIter` API forbids the backing map from being mutated for 'a, and
// guarantees that it's the correct key and value types.
// - The thunk is safe to call as long as the iterator isn't at the end.
// - The thunk always writes to key and value fields and does not read.
// - The thunk does not increment the iterator.
unsafe {
iter.as_raw_mut(Private).next_unchecked::<Key, Self, _, _>(
|iter, key, value| Key::iter_get(iter, Self::get_prototype(), key, value),
|ffi_key| Key::to_view(ffi_key),
|value| Self::from_map_value(value),
)
}
}
}
macro_rules! impl_map_primitives {
(@impl $(($rust_type:ty, $cpp_type:ty) => [
$free_thunk:ident,
$clear_thunk:ident,
$insert_thunk:ident,
$get_thunk:ident,
$iter_get_thunk:ident,
@ -811,24 +1160,32 @@ macro_rules! impl_map_primitives {
]),* $(,)?) => {
$(
extern "C" {
pub fn $free_thunk(
m: RawMap,
prototype: MapValue,
);
pub fn $clear_thunk(
m: RawMap,
prototype: MapValue,
);
pub fn $insert_thunk(
m: RawMap,
key: $cpp_type,
value: RawMessage,
value: MapValue,
) -> bool;
pub fn $get_thunk(
m: RawMap,
prototype: RawMessage,
prototype: MapValue,
key: $cpp_type,
value: *mut RawMessage,
value: *mut MapValue,
) -> bool;
pub fn $iter_get_thunk(
iter: &mut UntypedMapIterator,
prototype: RawMessage,
prototype: MapValue,
key: *mut $cpp_type,
value: *mut RawMessage,
value: *mut MapValue,
);
pub fn $remove_thunk(m: RawMap, prototype: RawMessage, key: $cpp_type) -> bool;
pub fn $remove_thunk(m: RawMap, prototype: MapValue, key: $cpp_type) -> bool;
}
)*
};
@ -836,6 +1193,8 @@ macro_rules! impl_map_primitives {
paste!{
impl_map_primitives!(@impl $(
($rust_type, $cpp_type) => [
[< proto2_rust_map_free_ $rust_type >],
[< proto2_rust_map_clear_ $rust_type >],
[< proto2_rust_map_insert_ $rust_type >],
[< proto2_rust_map_get_ $rust_type >],
[< proto2_rust_map_iter_get_ $rust_type >],
@ -859,113 +1218,10 @@ extern "C" {
fn proto2_rust_thunk_UntypedMapIterator_increment(iter: &mut UntypedMapIterator);
pub fn proto2_rust_map_new() -> RawMap;
pub fn proto2_rust_map_free(m: RawMap, category: MapKeyCategory, prototype: RawMessage);
pub fn proto2_rust_map_clear(m: RawMap, category: MapKeyCategory, prototype: RawMessage);
pub fn proto2_rust_map_size(m: RawMap) -> usize;
pub fn proto2_rust_map_iter(m: RawMap) -> UntypedMapIterator;
}
macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types {
($key_t:ty, $ffi_key_t:ty, $to_ffi_key:expr, $from_ffi_key:expr, for $($t:ty, $ffi_view_t:ty, $ffi_value_t:ty, $to_ffi_value:expr, $from_ffi_value:expr;)*) => {
paste! { $(
extern "C" {
pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _new >]() -> RawMap;
pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _free >](m: RawMap);
pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _clear >](m: RawMap);
pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _size >](m: RawMap) -> usize;
pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _insert >](m: RawMap, key: $ffi_key_t, value: $ffi_value_t) -> bool;
pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _get >](m: RawMap, key: $ffi_key_t, value: *mut $ffi_view_t) -> bool;
pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _iter >](m: RawMap) -> UntypedMapIterator;
pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _iter_get >](iter: &mut UntypedMapIterator, key: *mut $ffi_key_t, value: *mut $ffi_view_t);
pub fn [< proto2_rust_thunk_Map_ $key_t _ $t _remove >](m: RawMap, key: $ffi_key_t, value: *mut $ffi_view_t) -> bool;
}
impl ProxiedInMapValue<$key_t> for $t {
fn map_new(_private: Private) -> Map<$key_t, Self> {
unsafe {
Map::from_inner(
Private,
InnerMap {
raw: [< proto2_rust_thunk_Map_ $key_t _ $t _new >](),
}
)
}
}
unsafe fn map_free(_private: Private, map: &mut Map<$key_t, Self>) {
// SAFETY:
// - `map.inner.raw` is a live `RawMap`
// - This function is only called once for `map` in `Drop`.
unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _free >](map.as_mut().as_raw(Private)); }
}
fn map_clear(mut map: MapMut<$key_t, Self>) {
unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _clear >](map.as_raw(Private)); }
}
fn map_len(map: MapView<$key_t, Self>) -> usize {
unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _size >](map.as_raw(Private)) }
}
fn map_insert(mut map: MapMut<$key_t, Self>, key: View<'_, $key_t>, value: impl IntoProxied<Self>) -> bool {
let ffi_key = $to_ffi_key(key);
let ffi_value = $to_ffi_value(value.into_proxied(Private));
unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _insert >](map.as_raw(Private), ffi_key, ffi_value) }
}
fn map_get<'a>(map: MapView<'a, $key_t, Self>, key: View<'_, $key_t>) -> Option<View<'a, Self>> {
let ffi_key = $to_ffi_key(key);
let mut ffi_value = MaybeUninit::uninit();
let found = unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _get >](map.as_raw(Private), ffi_key, ffi_value.as_mut_ptr()) };
if !found {
return None;
}
// SAFETY: if `found` is true, then the `ffi_value` was written to by `get`.
Some($from_ffi_value(unsafe { ffi_value.assume_init() }))
}
fn map_remove(mut map: MapMut<$key_t, Self>, key: View<'_, $key_t>) -> bool {
let ffi_key = $to_ffi_key(key);
let mut ffi_value = MaybeUninit::uninit();
unsafe { [< proto2_rust_thunk_Map_ $key_t _ $t _remove >](map.as_raw(Private), ffi_key, ffi_value.as_mut_ptr()) }
}
fn map_iter(map: MapView<$key_t, Self>) -> MapIter<$key_t, Self> {
// SAFETY:
// - The backing map for `map.as_raw` is valid for at least '_.
// - A View that is live for '_ guarantees the backing map is unmodified for '_.
// - The `iter` function produces an iterator that is valid for the key
// and value types, and live for at least '_.
unsafe {
MapIter::from_raw(
Private,
[< proto2_rust_thunk_Map_ $key_t _ $t _iter >](map.as_raw(Private))
)
}
}
fn map_iter_next<'a>(iter: &mut MapIter<'a, $key_t, Self>) -> Option<(View<'a, $key_t>, View<'a, Self>)> {
// SAFETY:
// - The `MapIter` API forbids the backing map from being mutated for 'a,
// and guarantees that it's the correct key and value types.
// - The thunk is safe to call as long as the iterator isn't at the end.
// - The thunk always writes to key and value fields and does not read.
// - The thunk does not increment the iterator.
unsafe {
iter.as_raw_mut(Private).next_unchecked::<$key_t, Self, _, _>(
|iter, key, value| { [< proto2_rust_thunk_Map_ $key_t _ $t _iter_get >](iter, key, value) },
$from_ffi_key,
$from_ffi_value,
)
}
}
}
)* }
}
}
fn str_to_ptrlen<'msg>(val: impl Into<&'msg ProtoStr>) -> PtrAndLen {
val.into().as_bytes().into()
}
@ -990,36 +1246,6 @@ fn ptrlen_to_bytes<'msg>(val: PtrAndLen) -> &'msg [u8] {
unsafe { val.as_ref() }
}
macro_rules! impl_ProxiedInMapValue_for_key_types {
($($t:ty, $ffi_t:ty, $to_ffi_key:expr, $from_ffi_key:expr;)*) => {
paste! {
$(
impl_ProxiedInMapValue_for_non_generated_value_types!(
$t, $ffi_t, $to_ffi_key, $from_ffi_key, for
f32, f32, f32, identity, identity;
f64, f64, f64, identity, identity;
i32, i32, i32, identity, identity;
u32, u32, u32, identity, identity;
i64, i64, i64, identity, identity;
u64, u64, u64, identity, identity;
bool, bool, bool, identity, identity;
ProtoString, PtrAndLen, CppStdString, protostr_into_cppstdstring, ptrlen_to_str;
ProtoBytes, PtrAndLen, CppStdString, protobytes_into_cppstdstring, ptrlen_to_bytes;
);
)*
}
}
}
impl_ProxiedInMapValue_for_key_types!(
i32, i32, identity, identity;
u32, u32, identity, identity;
i64, i64, identity, identity;
u64, u64, identity, identity;
bool, bool, identity, identity;
ProtoString, PtrAndLen, str_to_ptrlen, ptrlen_to_str;
);
#[cfg(test)]
mod tests {
use super::*;

@ -15,7 +15,6 @@ cc_library(
hdrs = [
"compare.h",
"debug.h",
"map.h",
"rust_alloc_for_cpp_api.h",
"serialized_data.h",
"strings.h",

@ -1,5 +1,6 @@
#include "rust/cpp_kernel/map.h"
#include "google/protobuf/map.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>
@ -7,7 +8,6 @@
#include <utility>
#include "absl/log/absl_log.h"
#include "google/protobuf/map.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"
#include "rust/cpp_kernel/strings.h"
@ -17,6 +17,27 @@ namespace protobuf {
namespace rust {
namespace {
// LINT.IfChange(map_ffi)
enum class MapValueTag : uint8_t {
kBool,
kU32,
kU64,
kString,
kMessage,
};
struct MapValue {
MapValueTag tag;
union {
bool b;
uint32_t u32;
uint64_t u64;
std::string* s;
google::protobuf::MessageLite* message;
};
};
// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_ffi)
template <typename T>
struct FromViewType {
using type = T;
@ -31,41 +52,85 @@ template <typename Key>
using KeyMap = internal::KeyMapBase<
internal::KeyForBase<typename FromViewType<Key>::type>>;
internal::MapNodeSizeInfoT GetSizeInfo(size_t key_size,
const google::protobuf::MessageLite* value) {
void GetSizeAndAlignment(MapValue value, uint16_t* size, uint8_t* alignment) {
switch (value.tag) {
case MapValueTag::kBool:
*size = sizeof(bool);
*alignment = alignof(bool);
break;
case MapValueTag::kU32:
*size = sizeof(uint32_t);
*alignment = alignof(uint32_t);
break;
case MapValueTag::kU64:
*size = sizeof(uint64_t);
*alignment = alignof(uint64_t);
break;
case MapValueTag::kString:
*size = sizeof(std::string);
*alignment = alignof(std::string);
break;
case MapValueTag::kMessage:
internal::RustMapHelper::GetSizeAndAlignment(value.message, size,
alignment);
break;
default:
ABSL_DLOG(FATAL) << "Unexpected value of MapValue";
}
}
internal::MapNodeSizeInfoT GetSizeInfo(size_t key_size, MapValue value) {
// Each map node consists of a NodeBase followed by a std::pair<Key, Value>.
// We need to compute the offset of the value and the total size of the node.
size_t node_and_key_size = sizeof(internal::NodeBase) + key_size;
uint16_t value_size;
uint8_t value_alignment;
internal::RustMapHelper::GetSizeAndAlignment(value, &value_size,
&value_alignment);
GetSizeAndAlignment(value, &value_size, &value_alignment);
// Round node_and_key_size up to the nearest multiple of value_alignment.
uint16_t offset =
(((node_and_key_size - 1) / value_alignment) + 1) * value_alignment;
return internal::RustMapHelper::MakeSizeInfo(offset + value_size, offset);
}
template <typename Key>
internal::MapNodeSizeInfoT GetSizeInfo(const google::protobuf::MessageLite* value) {
return GetSizeInfo(sizeof(Key), value);
size_t overall_alignment = std::max(alignof(internal::NodeBase),
static_cast<size_t>(value_alignment));
// Round up size to nearest multiple of overall_alignment.
size_t overall_size =
(((offset + value_size - 1) / overall_alignment) + 1) * overall_alignment;
return internal::RustMapHelper::MakeSizeInfo(overall_size, offset);
}
template <typename Key>
void DestroyMapNode(internal::UntypedMapBase* m, internal::NodeBase* node,
internal::MapNodeSizeInfoT size_info) {
internal::MapNodeSizeInfoT size_info,
bool destroy_message) {
if constexpr (std::is_same<Key, PtrAndLen>::value) {
static_cast<std::string*>(node->GetVoidKey())->~basic_string();
}
internal::RustMapHelper::DestroyMessage(
static_cast<MessageLite*>(node->GetVoidValue(size_info)));
if (destroy_message) {
internal::RustMapHelper::DestroyMessage(
static_cast<MessageLite*>(node->GetVoidValue(size_info)));
}
internal::RustMapHelper::DeallocNode(m, node, size_info);
}
void InitializeMessageValue(void* raw_ptr, MessageLite* msg) {
MessageLite* new_msg = internal::RustMapHelper::PlacementNew(msg, raw_ptr);
auto* full_msg = DynamicCastMessage<Message>(new_msg);
// If we are working with a full (non-lite) proto, we reflectively swap the
// value into place. Otherwise, we have to perform a copy.
if (full_msg != nullptr) {
full_msg->GetReflection()->Swap(full_msg, DynamicCastMessage<Message>(msg));
} else {
new_msg->CheckTypeAndMergeFrom(*msg);
}
delete msg;
}
template <typename Key>
bool Insert(internal::UntypedMapBase* m, Key key, MessageLite* value) {
bool Insert(internal::UntypedMapBase* m, Key key, MapValue value) {
internal::MapNodeSizeInfoT size_info =
GetSizeInfo<typename FromViewType<Key>::type>(value);
GetSizeInfo(sizeof(typename FromViewType<Key>::type), value);
internal::NodeBase* node = internal::RustMapHelper::AllocNode(m, size_info);
if constexpr (std::is_same<Key, PtrAndLen>::value) {
new (node->GetVoidKey()) std::string(key.ptr, key.len);
@ -73,17 +138,26 @@ bool Insert(internal::UntypedMapBase* m, Key key, MessageLite* value) {
*static_cast<Key*>(node->GetVoidKey()) = key;
}
MessageLite* new_msg = internal::RustMapHelper::PlacementNew(
value, node->GetVoidValue(size_info));
auto* full_msg = DynamicCastMessage<Message>(new_msg);
// If we are working with a full (non-lite) proto, we reflectively swap the
// value into place. Otherwise, we have to perform a copy.
if (full_msg != nullptr) {
full_msg->GetReflection()->Swap(full_msg,
DynamicCastMessage<Message>(value));
} else {
new_msg->CheckTypeAndMergeFrom(*value);
void* value_ptr = node->GetVoidValue(size_info);
switch (value.tag) {
case MapValueTag::kBool:
*static_cast<bool*>(value_ptr) = value.b;
break;
case MapValueTag::kU32:
*static_cast<uint32_t*>(value_ptr) = value.u32;
break;
case MapValueTag::kU64:
*static_cast<uint64_t*>(value_ptr) = value.u64;
break;
case MapValueTag::kString:
new (value_ptr) std::string(std::move(*value.s));
delete value.s;
break;
case MapValueTag::kMessage:
InitializeMessageValue(value_ptr, value.message);
break;
default:
ABSL_DLOG(FATAL) << "Unexpected value of MapValue";
}
node = internal::RustMapHelper::InsertOrReplaceNode(
@ -91,7 +165,7 @@ bool Insert(internal::UntypedMapBase* m, Key key, MessageLite* value) {
if (node == nullptr) {
return true;
}
DestroyMapNode<Key>(m, node, size_info);
DestroyMapNode<Key>(m, node, size_info, value.tag == MapValueTag::kMessage);
return false;
}
@ -110,41 +184,63 @@ internal::RustMapHelper::NodeAndBucket FindHelper(Map* m,
m, absl::string_view(key.ptr, key.len));
}
void PopulateMapValue(MapValueTag tag, void* data, MapValue& output) {
output.tag = tag;
switch (tag) {
case MapValueTag::kBool:
output.b = *static_cast<const bool*>(data);
break;
case MapValueTag::kU32:
output.u32 = *static_cast<const uint32_t*>(data);
break;
case MapValueTag::kU64:
output.u64 = *static_cast<const uint64_t*>(data);
break;
case MapValueTag::kString:
output.s = static_cast<std::string*>(data);
break;
case MapValueTag::kMessage:
output.message = static_cast<MessageLite*>(data);
break;
default:
ABSL_DLOG(FATAL) << "Unexpected MapValueTag";
}
}
template <typename Key>
bool Get(internal::UntypedMapBase* m, const google::protobuf::MessageLite* prototype,
Key key, MessageLite** value) {
bool Get(internal::UntypedMapBase* m, MapValue prototype, Key key,
MapValue* value) {
internal::MapNodeSizeInfoT size_info =
GetSizeInfo<typename FromViewType<Key>::type>(prototype);
GetSizeInfo(sizeof(typename FromViewType<Key>::type), prototype);
auto* map_base = static_cast<KeyMap<Key>*>(m);
internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key);
if (result.node == nullptr) {
return false;
}
*value = static_cast<MessageLite*>(result.node->GetVoidValue(size_info));
PopulateMapValue(prototype.tag, result.node->GetVoidValue(size_info), *value);
return true;
}
template <typename Key>
bool Remove(internal::UntypedMapBase* m, const google::protobuf::MessageLite* prototype,
Key key) {
bool Remove(internal::UntypedMapBase* m, MapValue prototype, Key key) {
internal::MapNodeSizeInfoT size_info =
GetSizeInfo<typename FromViewType<Key>::type>(prototype);
GetSizeInfo(sizeof(typename FromViewType<Key>::type), prototype);
auto* map_base = static_cast<KeyMap<Key>*>(m);
internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key);
if (result.node == nullptr) {
return false;
}
internal::RustMapHelper::EraseNoDestroy(map_base, result.bucket, result.node);
DestroyMapNode<Key>(m, result.node, size_info);
DestroyMapNode<Key>(m, result.node, size_info,
prototype.tag == MapValueTag::kMessage);
return true;
}
template <typename Key>
void IterGet(const internal::UntypedMapIterator* iter,
const google::protobuf::MessageLite* prototype, Key* key,
MessageLite** value) {
void IterGet(const internal::UntypedMapIterator* iter, MapValue prototype,
Key* key, MapValue* value) {
internal::MapNodeSizeInfoT size_info =
GetSizeInfo<typename FromViewType<Key>::type>(prototype);
GetSizeInfo(sizeof(typename FromViewType<Key>::type), prototype);
internal::NodeBase* node = iter->node_;
if constexpr (std::is_same<Key, PtrAndLen>::value) {
const std::string* s = static_cast<const std::string*>(node->GetVoidKey());
@ -152,42 +248,35 @@ void IterGet(const internal::UntypedMapIterator* iter,
} else {
*key = *static_cast<const Key*>(node->GetVoidKey());
}
*value = static_cast<MessageLite*>(node->GetVoidValue(size_info));
PopulateMapValue(prototype.tag, node->GetVoidValue(size_info), *value);
}
// LINT.IfChange(map_key_category)
enum class MapKeyCategory : uint8_t {
kOneByte = 0,
kFourBytes = 1,
kEightBytes = 2,
kStdString = 3,
};
// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_key_category)
size_t KeySize(MapKeyCategory category) {
switch (category) {
case MapKeyCategory::kOneByte:
return 1;
case MapKeyCategory::kFourBytes:
return 4;
case MapKeyCategory::kEightBytes:
return 8;
case MapKeyCategory::kStdString:
return sizeof(std::string);
default:
ABSL_DLOG(FATAL) << "Unexpected value of MapKeyCategory enum";
// Returns the size of the key in the map entry, given the key used for FFI.
// The map entry key and FFI key are always the same, except in the case of
// string and bytes.
template <typename Key>
size_t KeySize() {
if constexpr (std::is_same<Key, google::protobuf::rust::PtrAndLen>::value) {
return sizeof(std::string);
} else {
return sizeof(Key);
}
}
void ClearMap(internal::UntypedMapBase* m, MapKeyCategory category,
bool reset_table, const google::protobuf::MessageLite* prototype) {
internal::MapNodeSizeInfoT size_info =
GetSizeInfo(KeySize(category), prototype);
template <typename Key>
void ClearMap(internal::UntypedMapBase* m, bool reset_table,
MapValue prototype) {
internal::MapNodeSizeInfoT size_info = GetSizeInfo(KeySize<Key>(), prototype);
if (internal::RustMapHelper::IsGlobalEmptyTable(m)) return;
uint8_t bits = internal::RustMapHelper::kValueIsProto;
if (category == MapKeyCategory::kStdString) {
uint8_t bits = 0;
if constexpr (std::is_same<Key, google::protobuf::rust::PtrAndLen>::value) {
bits |= internal::RustMapHelper::kKeyIsString;
}
if (prototype.tag == MapValueTag::kString) {
bits |= internal::RustMapHelper::kValueIsString;
} else if (prototype.tag == MapValueTag::kMessage) {
bits |= internal::RustMapHelper::kValueIsProto;
}
internal::RustMapHelper::ClearTable(
m, internal::RustMapHelper::ClearInput{size_info, bits, reset_table,
/* destroy_node = */ nullptr});
@ -209,19 +298,6 @@ google::protobuf::internal::UntypedMapBase* proto2_rust_map_new() {
return new google::protobuf::internal::UntypedMapBase(/* arena = */ nullptr);
}
void proto2_rust_map_free(google::protobuf::internal::UntypedMapBase* m,
google::protobuf::rust::MapKeyCategory category,
const google::protobuf::MessageLite* prototype) {
google::protobuf::rust::ClearMap(m, category, /* reset_table = */ false, prototype);
delete m;
}
void proto2_rust_map_clear(google::protobuf::internal::UntypedMapBase* m,
google::protobuf::rust::MapKeyCategory category,
const google::protobuf::MessageLite* prototype) {
google::protobuf::rust::ClearMap(m, category, /* reset_table = */ true, prototype);
}
size_t proto2_rust_map_size(google::protobuf::internal::UntypedMapBase* m) {
return m->size();
}
@ -231,31 +307,39 @@ google::protobuf::internal::UntypedMapIterator proto2_rust_map_iter(
return m->begin();
}
#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \
bool proto2_rust_map_insert_##suffix(google::protobuf::internal::UntypedMapBase* m, \
cpp_type key, \
google::protobuf::MessageLite* value) { \
return google::protobuf::rust::Insert(m, key, value); \
} \
\
bool proto2_rust_map_get_##suffix(google::protobuf::internal::UntypedMapBase* m, \
const google::protobuf::MessageLite* prototype, \
cpp_type key, \
google::protobuf::MessageLite** value) { \
return google::protobuf::rust::Get(m, prototype, key, value); \
} \
\
bool proto2_rust_map_remove_##suffix(google::protobuf::internal::UntypedMapBase* m, \
const google::protobuf::MessageLite* prototype, \
cpp_type key) { \
return google::protobuf::rust::Remove(m, prototype, key); \
} \
\
void proto2_rust_map_iter_get_##suffix( \
const google::protobuf::internal::UntypedMapIterator* iter, \
const google::protobuf::MessageLite* prototype, cpp_type* key, \
google::protobuf::MessageLite** value) { \
return google::protobuf::rust::IterGet(iter, prototype, key, value); \
#define DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(cpp_type, suffix) \
void proto2_rust_map_free_##suffix(google::protobuf::internal::UntypedMapBase* m, \
google::protobuf::rust::MapValue prototype) { \
google::protobuf::rust::ClearMap<cpp_type>(m, /* reset_table = */ false, prototype); \
delete m; \
} \
void proto2_rust_map_clear_##suffix(google::protobuf::internal::UntypedMapBase* m, \
google::protobuf::rust::MapValue prototype) { \
google::protobuf::rust::ClearMap<cpp_type>(m, /* reset_table = */ true, prototype); \
} \
bool proto2_rust_map_insert_##suffix(google::protobuf::internal::UntypedMapBase* m, \
cpp_type key, \
google::protobuf::rust::MapValue value) { \
return google::protobuf::rust::Insert(m, key, value); \
} \
\
bool proto2_rust_map_get_##suffix( \
google::protobuf::internal::UntypedMapBase* m, google::protobuf::rust::MapValue prototype, \
cpp_type key, google::protobuf::rust::MapValue* value) { \
return google::protobuf::rust::Get(m, prototype, key, value); \
} \
\
bool proto2_rust_map_remove_##suffix(google::protobuf::internal::UntypedMapBase* m, \
google::protobuf::rust::MapValue prototype, \
cpp_type key) { \
return google::protobuf::rust::Remove(m, prototype, key); \
} \
\
void proto2_rust_map_iter_get_##suffix( \
const google::protobuf::internal::UntypedMapIterator* iter, \
google::protobuf::rust::MapValue prototype, cpp_type* key, \
google::protobuf::rust::MapValue* value) { \
return google::protobuf::rust::IterGet(iter, prototype, key, value); \
}
DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(int32_t, i32)
@ -265,27 +349,4 @@ DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(uint64_t, u64)
DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(bool, bool)
DEFINE_KEY_SPECIFIC_MAP_OPERATIONS(google::protobuf::rust::PtrAndLen, ProtoString)
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(int32_t, i32, int32_t,
int32_t, value, cpp_value);
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(uint32_t, u32, uint32_t,
uint32_t, value, cpp_value);
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(float, f32, float, float,
value, cpp_value);
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(double, f64, double, double,
value, cpp_value);
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(bool, bool, bool, bool,
value, cpp_value);
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(uint64_t, u64, uint64_t,
uint64_t, value, cpp_value);
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(int64_t, i64, int64_t,
int64_t, value, cpp_value);
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(
std::string, ProtoBytes, google::protobuf::rust::PtrAndLen, std::string*,
std::move(*value),
(google::protobuf::rust::PtrAndLen{cpp_value.data(), cpp_value.size()}));
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(
std::string, ProtoString, google::protobuf::rust::PtrAndLen, std::string*,
std::move(*value),
(google::protobuf::rust::PtrAndLen{cpp_value.data(), cpp_value.size()}));
} // extern "C"

@ -1,122 +0,0 @@
#ifndef GOOGLE_PROTOBUF_RUST_CPP_KERNEL_MAP_H__
#define GOOGLE_PROTOBUF_RUST_CPP_KERNEL_MAP_H__
#include <memory>
#include <type_traits>
#include "google/protobuf/map.h"
#include "google/protobuf/message_lite.h"
#include "rust/cpp_kernel/strings.h"
namespace google {
namespace protobuf {
namespace rust {
// String and bytes values are passed across the FFI boundary as owned raw
// pointers when we do map insertions. Unlike other types, they have to be
// explicitly deleted. This MakeCleanup() helper does nothing by default, but
// for std::string pointers it returns a std::unique_ptr to take ownership of
// the raw pointer.
template <typename T>
auto MakeCleanup(T value) {
if constexpr (std::is_same<T, std::string*>::value) {
return std::unique_ptr<std::string>(value);
} else {
return 0;
}
}
} // namespace rust
} // namespace protobuf
} // namespace google
// Defines concrete thunks to access typed map methods from Rust.
#define __PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \
key_ty, rust_key_ty, ffi_key_ty, to_cpp_key, to_ffi_key, value_ty, \
rust_value_ty, ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value) \
google::protobuf::Map<key_ty, value_ty>* \
proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_new() { \
return new google::protobuf::Map<key_ty, value_ty>(); \
} \
void proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_free( \
google::protobuf::Map<key_ty, value_ty>* m) { \
delete m; \
} \
void proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_clear( \
google::protobuf::Map<key_ty, value_ty>* m) { \
m->clear(); \
} \
size_t proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_size( \
const google::protobuf::Map<key_ty, value_ty>* m) { \
return m->size(); \
} \
bool proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_insert( \
google::protobuf::Map<key_ty, value_ty>* m, ffi_key_ty key, ffi_value_ty value) { \
auto cleanup = google::protobuf::rust::MakeCleanup(value); \
(void)cleanup; \
auto iter_and_inserted = m->try_emplace(to_cpp_key, to_cpp_value); \
if (!iter_and_inserted.second) { \
iter_and_inserted.first->second = to_cpp_value; \
} \
return iter_and_inserted.second; \
} \
bool proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_get( \
const google::protobuf::Map<key_ty, value_ty>* m, ffi_key_ty key, \
ffi_view_ty* value) { \
auto cpp_key = to_cpp_key; \
auto it = m->find(cpp_key); \
if (it == m->end()) { \
return false; \
} \
auto& cpp_value = it->second; \
*value = to_ffi_value; \
return true; \
} \
google::protobuf::internal::UntypedMapIterator \
proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_iter( \
const google::protobuf::Map<key_ty, value_ty>* m) { \
return google::protobuf::internal::UntypedMapIterator::FromTyped(m->cbegin()); \
} \
void proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_iter_get( \
const google::protobuf::internal::UntypedMapIterator* iter, ffi_key_ty* key, \
ffi_view_ty* value) { \
auto typed_iter = \
iter->ToTyped<google::protobuf::Map<key_ty, value_ty>::const_iterator>(); \
const auto& cpp_key = typed_iter->first; \
const auto& cpp_value = typed_iter->second; \
*key = to_ffi_key; \
*value = to_ffi_value; \
} \
bool proto2_rust_thunk_Map_##rust_key_ty##_##rust_value_ty##_remove( \
google::protobuf::Map<key_ty, value_ty>* m, ffi_key_ty key, ffi_view_ty* value) { \
auto cpp_key = to_cpp_key; \
auto num_removed = m->erase(cpp_key); \
return num_removed > 0; \
}
// Defines the map thunks for all supported key types for a given value type.
#define __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE( \
value_ty, rust_value_ty, ffi_view_ty, ffi_value_ty, to_cpp_value, \
to_ffi_value) \
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \
int32_t, i32, int32_t, key, cpp_key, value_ty, rust_value_ty, \
ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value); \
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \
uint32_t, u32, uint32_t, key, cpp_key, value_ty, rust_value_ty, \
ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value); \
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \
bool, bool, bool, key, cpp_key, value_ty, rust_value_ty, ffi_view_ty, \
ffi_value_ty, to_cpp_value, to_ffi_value); \
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \
uint64_t, u64, uint64_t, key, cpp_key, value_ty, rust_value_ty, \
ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value); \
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \
int64_t, i64, int64_t, key, cpp_key, value_ty, rust_value_ty, \
ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value); \
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS( \
std::string, ProtoString, google::protobuf::rust::PtrAndLen, \
std::string(key.ptr, key.len), \
(google::protobuf::rust::PtrAndLen{cpp_key.data(), cpp_key.size()}), value_ty, \
rust_value_ty, ffi_view_ty, ffi_value_ty, to_cpp_value, to_ffi_value);
#endif // GOOGLE_PROTOBUF_RUST_CPP_KERNEL_MAP_H__

@ -45,106 +45,26 @@ std::vector<std::pair<absl::string_view, int32_t>> EnumValuesInput(
return result;
}
void EnumProxiedInMapValue(Context& ctx, const EnumDescriptor& desc) {
void TypeConversions(Context& ctx, const EnumDescriptor& desc) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
for (const auto& t : kMapKeyTypes) {
ctx.Emit(
{{"map_new_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "new")},
{"map_free_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "free")},
{"map_clear_thunk",
RawMapThunk(ctx, desc, t.thunk_ident, "clear")},
{"map_size_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "size")},
{"map_insert_thunk",
RawMapThunk(ctx, desc, t.thunk_ident, "insert")},
{"map_get_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "get")},
{"map_remove_thunk",
RawMapThunk(ctx, desc, t.thunk_ident, "remove")},
{"map_iter_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "iter")},
{"map_iter_get_thunk",
RawMapThunk(ctx, desc, t.thunk_ident, "iter_get")},
{"to_ffi_key_expr", t.rs_to_ffi_key_expr},
io::Printer::Sub("ffi_key_t", [&] { ctx.Emit(t.rs_ffi_key_t); })
.WithSuffix(""),
io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); })
.WithSuffix(""),
io::Printer::Sub("from_ffi_key_expr",
[&] { ctx.Emit(t.rs_from_ffi_key_expr); })
.WithSuffix("")},
R"rs(
impl $pb$::ProxiedInMapValue<$key_t$> for $name$ {
fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> {
unsafe {
$pb$::Map::from_inner(
$pbi$::Private,
$pbr$::InnerMap::new($pbr$::$map_new_thunk$())
)
}
}
unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) {
unsafe { $pbr$::$map_free_thunk$(map.as_raw($pbi$::Private)); }
}
fn map_clear(mut map: $pb$::MapMut<$key_t$, Self>) {
unsafe { $pbr$::$map_clear_thunk$(map.as_raw($pbi$::Private)); }
}
fn map_len(map: $pb$::MapView<$key_t$, Self>) -> usize {
unsafe { $pbr$::$map_size_thunk$(map.as_raw($pbi$::Private)) }
}
fn map_insert(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>, value: impl $pb$::IntoProxied<Self>) -> bool {
unsafe { $pbr$::$map_insert_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value.into_proxied($pbi$::Private).0) }
}
fn map_get<'a>(map: $pb$::MapView<'a, $key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> $Option$<$pb$::View<'a, Self>> {
let key = $to_ffi_key_expr$;
let mut value = $std$::mem::MaybeUninit::uninit();
let found = unsafe { $pbr$::$map_get_thunk$(map.as_raw($pbi$::Private), key, value.as_mut_ptr()) };
if !found {
return None;
}
Some(unsafe { $name$(value.assume_init()) })
}
fn map_remove(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> bool {
let mut value = $std$::mem::MaybeUninit::uninit();
unsafe { $pbr$::$map_remove_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value.as_mut_ptr()) }
}
ctx.Emit(
R"rs(
impl $pbr$::CppMapTypeConversions for $name$ {
fn get_prototype() -> $pbr$::MapValue {
Self::to_map_value(Self::default())
}
fn map_iter(map: $pb$::MapView<$key_t$, Self>) -> $pb$::MapIter<$key_t$, Self> {
// SAFETY:
// - The backing map for `map.as_raw` is valid for at least '_.
// - A View that is live for '_ guarantees the backing map is unmodified for '_.
// - The `iter` function produces an iterator that is valid for the key
// and value types, and live for at least '_.
unsafe {
$pb$::MapIter::from_raw(
$pbi$::Private,
$pbr$::$map_iter_thunk$(map.as_raw($pbi$::Private))
)
}
}
fn to_map_value(self) -> $pbr$::MapValue {
$pbr$::MapValue::make_u32(self.0 as u32)
}
fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> $Option$<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> {
// SAFETY:
// - The `MapIter` API forbids the backing map from being mutated for 'a,
// and guarantees that it's the correct key and value types.
// - The thunk is safe to call as long as the iterator isn't at the end.
// - The thunk always writes to key and value fields and does not read.
// - The thunk does not increment the iterator.
unsafe {
iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>(
|iter, key, value| { $pbr$::$map_iter_get_thunk$(iter, key, value) },
|ffi_key| $from_ffi_key_expr$,
|value| $name$(value),
)
}
}
}
)rs");
}
unsafe fn from_map_value<'a>(value: $pbr$::MapValue) -> $pb$::View<'a, Self> {
debug_assert_eq!(value.tag, $pbr$::MapValueTag::U32);
$name$(unsafe { value.val.u as i32 })
}
}
)rs");
return;
case Kernel::kUpb:
ctx.Emit(R"rs(
@ -277,7 +197,7 @@ void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) {
)rs");
}
}},
{"impl_proxied_in_map", [&] { EnumProxiedInMapValue(ctx, desc); }},
{"type_conversions_impl", [&] { TypeConversions(ctx, desc); }},
},
R"rs(
#[repr(transparent)]
@ -411,7 +331,7 @@ void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) {
}
}
$impl_proxied_in_map$
$type_conversions_impl$
)rs");
}

@ -225,7 +225,6 @@ bool RustGenerator::Generate(const FileDescriptor* file,
#include "google/protobuf/map.h"
#include "google/protobuf/repeated_field.h"
#include "google/protobuf/repeated_ptr_field.h"
#include "rust/cpp_kernel/map.h"
#include "rust/cpp_kernel/serialized_data.h"
#include "rust/cpp_kernel/strings.h"
)cc");

@ -573,120 +573,29 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) {
ABSL_LOG(FATAL) << "unreachable";
}
void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) {
void TypeConversions(Context& ctx, const Descriptor& msg) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
for (const auto& t : kMapKeyTypes) {
ctx.Emit(
{{"map_insert",
absl::StrCat("proto2_rust_map_insert_", t.thunk_ident)},
{"map_remove",
absl::StrCat("proto2_rust_map_remove_", t.thunk_ident)},
{"map_get", absl::StrCat("proto2_rust_map_get_", t.thunk_ident)},
{"map_iter_get",
absl::StrCat("proto2_rust_map_iter_get_", t.thunk_ident)},
{"key_expr", t.rs_to_ffi_key_expr},
{"key_is_string",
t.thunk_ident == "ProtoString" ? "true" : "false"},
io::Printer::Sub("ffi_key_t", [&] { ctx.Emit(t.rs_ffi_key_t); })
.WithSuffix(""),
io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); })
.WithSuffix(""),
io::Printer::Sub("from_ffi_key_expr",
[&] { ctx.Emit(t.rs_from_ffi_key_expr); })
.WithSuffix("")},
R"rs(
impl $pb$::ProxiedInMapValue<$key_t$> for $Msg$ {
fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> {
unsafe {
$pb$::Map::from_inner(
$pbi$::Private,
$pbr$::InnerMap::new($pbr$::proto2_rust_map_new())
)
}
}
unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) {
use $pbr$::MapKey;
unsafe { $pbr$::proto2_rust_map_free(map.as_raw($pbi$::Private), $key_t$::CATEGORY, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg()); }
}
fn map_clear(mut map: $pb$::MapMut<$key_t$, Self>) {
use $pbr$::MapKey;
unsafe { $pbr$::proto2_rust_map_clear(map.as_raw($pbi$::Private), $key_t$::CATEGORY, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg()); }
}
fn map_len(map: $pb$::MapView<$key_t$, Self>) -> usize {
unsafe { $pbr$::proto2_rust_map_size(map.as_raw($pbi$::Private)) }
}
fn map_insert(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>, value: impl $pb$::IntoProxied<Self>) -> bool {
unsafe {
$pbr$::$map_insert$(
map.as_raw($pbi$::Private),
$key_expr$,
value.into_proxied($pbi$::Private).raw_msg())
}
}
fn map_get<'a>(map: $pb$::MapView<'a, $key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> $Option$<$pb$::View<'a, Self>> {
let key = $key_expr$;
let mut value = $std$::mem::MaybeUninit::uninit();
let found = unsafe {
$pbr$::$map_get$(
map.as_raw($pbi$::Private),
<$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(),
key,
value.as_mut_ptr())
};
if !found {
return None;
}
Some($Msg$View::new($pbi$::Private, unsafe { value.assume_init() }))
}
fn map_remove(mut map: $pb$::MapMut<$key_t$, Self>, key: $pb$::View<'_, $key_t$>) -> bool {
unsafe {
$pbr$::$map_remove$(
map.as_raw($pbi$::Private),
<$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(),
$key_expr$)
}
}
ctx.Emit(
R"rs(
impl $pbr$::CppMapTypeConversions for $Msg$ {
fn get_prototype() -> $pbr$::MapValue {
$pbr$::MapValue::make_message(<$Msg$View as $std$::default::Default>::default().raw_msg())
}
fn map_iter(map: $pb$::MapView<$key_t$, Self>) -> $pb$::MapIter<$key_t$, Self> {
// SAFETY:
// - The backing map for `map.as_raw` is valid for at least '_.
// - A View that is live for '_ guarantees the backing map is unmodified for '_.
// - The `iter` function produces an iterator that is valid for the key
// and value types, and live for at least '_.
unsafe {
$pb$::MapIter::from_raw(
$pbi$::Private,
$pbr$::proto2_rust_map_iter(map.as_raw($pbi$::Private))
)
}
}
fn to_map_value(self) -> $pbr$::MapValue {
use $pb$::OwnedMessageInterop;
$pbr$::MapValue::make_message(unsafe {
$NonNull$::new_unchecked(self.__unstable_leak_raw_message() as *mut _)
})
}
fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> $Option$<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> {
// SAFETY:
// - The `MapIter` API forbids the backing map from being mutated for 'a,
// and guarantees that it's the correct key and value types.
// - The thunk is safe to call as long as the iterator isn't at the end.
// - The thunk always writes to key and value fields and does not read.
// - The thunk does not increment the iterator.
unsafe {
iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>(
|iter, key, value| { $pbr$::$map_iter_get$(
iter, <$pb$::View::<$Msg$> as std::default::Default>::default().raw_msg(), key, value) },
|ffi_key| $from_ffi_key_expr$,
|raw_msg| $Msg$View::new($pbi$::Private, raw_msg)
)
}
}
}
)rs");
}
unsafe fn from_map_value<'b>(value: $pbr$::MapValue) -> $Msg$View<'b> {
debug_assert_eq!(value.tag, $pbr$::MapValueTag::Message);
unsafe { $Msg$View::new($pbi$::Private, value.val.m) }
}
}
)rs");
return;
case Kernel::kUpb:
ctx.Emit(
@ -845,7 +754,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
{"upb_generated_message_trait_impls",
[&] { UpbGeneratedMessageTraitImpls(ctx, msg); }},
{"repeated_impl", [&] { MessageProxiedInRepeated(ctx, msg); }},
{"map_value_impl", [&] { MessageProxiedInMapValue(ctx, msg); }},
{"type_conversions_impl", [&] { TypeConversions(ctx, msg); }},
{"unwrap_upb",
[&] {
if (ctx.is_upb()) {
@ -1008,7 +917,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
$into_proxied_impl$
$repeated_impl$
$map_value_impl$
$type_conversions_impl$
#[allow(dead_code)]
#[allow(non_camel_case_types)]

@ -405,44 +405,6 @@ absl::string_view MultiCasePrefixStripper::StripPrefix(
return name;
}
PROTOBUF_CONSTINIT const MapKeyType kMapKeyTypes[] = {
{/*thunk_ident=*/"i32", /*rs_key_t=*/"i32", /*rs_ffi_key_t=*/"i32",
/*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
/*cc_key_t=*/"int32_t", /*cc_ffi_key_t=*/"int32_t",
/*cc_from_ffi_key_expr=*/"key",
/*cc_to_ffi_key_expr=*/"cpp_key"},
{/*thunk_ident=*/"u32", /*rs_key_t=*/"u32", /*rs_ffi_key_t=*/"u32",
/*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
/*cc_key_t=*/"uint32_t", /*cc_ffi_key_t=*/"uint32_t",
/*cc_from_ffi_key_expr=*/"key",
/*cc_to_ffi_key_expr=*/"cpp_key"},
{/*thunk_ident=*/"i64", /*rs_key_t=*/"i64", /*rs_ffi_key_t=*/"i64",
/*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
/*cc_key_t=*/"int64_t", /*cc_ffi_key_t=*/"int64_t",
/*cc_from_ffi_key_expr=*/"key",
/*cc_to_ffi_key_expr=*/"cpp_key"},
{/*thunk_ident=*/"u64", /*rs_key_t=*/"u64", /*rs_ffi_key_t=*/"u64",
/*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
/*cc_key_t=*/"uint64_t", /*cc_ffi_key_t=*/"uint64_t",
/*cc_from_ffi_key_expr=*/"key",
/*cc_to_ffi_key_expr=*/"cpp_key"},
{/*thunk_ident=*/"bool", /*rs_key_t=*/"bool", /*rs_ffi_key_t=*/"bool",
/*rs_to_ffi_key_expr=*/"key", /*rs_from_ffi_key_expr=*/"ffi_key",
/*cc_key_t=*/"bool", /*cc_ffi_key_t=*/"bool",
/*cc_from_ffi_key_expr=*/"key",
/*cc_to_ffi_key_expr=*/"cpp_key"},
{/*thunk_ident=*/"ProtoString",
/*rs_key_t=*/"$pb$::ProtoString",
/*rs_ffi_key_t=*/"$pbr$::PtrAndLen",
/*rs_to_ffi_key_expr=*/"key.as_bytes().into()",
/*rs_from_ffi_key_expr=*/
"$pb$::ProtoStr::from_utf8_unchecked(ffi_key.as_ref())",
/*cc_key_t=*/"std::string",
/*cc_ffi_key_t=*/"google::protobuf::rust::PtrAndLen",
/*cc_from_ffi_key_expr=*/
"std::string(key.ptr, key.len)", /*cc_to_ffi_key_expr=*/
"google::protobuf::rust::PtrAndLen{cpp_key.data(), cpp_key.size()}"}};
} // namespace rust
} // namespace compiler
} // namespace protobuf

@ -1194,6 +1194,7 @@ class RustMapHelper {
enum {
kKeyIsString = UntypedMapBase::kKeyIsString,
kValueIsString = UntypedMapBase::kValueIsString,
kValueIsProto = UntypedMapBase::kValueIsProto,
};

Loading…
Cancel
Save