Implement support for messages as map values

PiperOrigin-RevId: 605581725
pull/15719/head
Marcel Hlopko 10 months ago committed by Copybara-Service
parent 7da29c6a13
commit 5842cc9c3d
  1. 7
      rust/cpp.rs
  2. 14
      rust/map.rs
  3. 2
      rust/test/shared/BUILD
  4. 153
      rust/test/shared/accessors_map_test.rs
  5. 39
      rust/upb.rs
  6. 4
      src/google/protobuf/compiler/rust/accessors/accessors.cc
  7. 17
      src/google/protobuf/compiler/rust/accessors/map.cc
  8. 439
      src/google/protobuf/compiler/rust/message.cc
  9. 6
      src/google/protobuf/compiler/rust/naming.cc
  10. 2
      src/google/protobuf/compiler/rust/naming.h
  11. 15
      src/google/protobuf/map_unittest.proto

@ -402,10 +402,16 @@ pub struct InnerMapMut<'msg> {
_phantom: PhantomData<&'msg ()>,
}
#[doc(hidden)]
impl<'msg> InnerMapMut<'msg> {
pub fn new(_private: Private, raw: RawMap) -> Self {
InnerMapMut { raw, _phantom: PhantomData }
}
#[doc(hidden)]
pub fn as_raw(&self, _private: Private) -> RawMap {
self.raw
}
}
/// An untyped iterator in a map, produced via `.cbegin()` on a typed map.
@ -547,6 +553,7 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types {
let ffi_key = $to_ffi_key(key);
let mut ffi_value = MaybeUninit::uninit();
let found = unsafe { [< __rust_proto_thunk__Map_ $key_t _ $t _get >](map.as_raw(Private), ffi_key, ffi_value.as_mut_ptr()) };
if !found {
return None;
}

@ -42,6 +42,12 @@ pub struct MapMut<'msg, K: ?Sized, V: ?Sized> {
_phantom: PhantomData<(&'msg mut K, &'msg mut V)>,
}
impl<'msg, K: ?Sized, V: ?Sized> MapMut<'msg, K, V> {
pub fn inner(&self, _private: Private) -> InnerMapMut {
self.inner
}
}
unsafe impl<'msg, K: ?Sized, V: ?Sized> Sync for MapMut<'msg, K, V> {}
impl<'msg, K: ?Sized, V: ?Sized> std::fmt::Debug for MapMut<'msg, K, V> {
@ -178,6 +184,14 @@ where
pub unsafe fn from_inner(_private: Private, inner: InnerMapMut<'static>) -> Self {
Self { inner, _phantom: PhantomData }
}
pub fn as_raw(&self, _private: Private) -> RawMap {
self.inner.as_raw(Private)
}
pub fn inner(&self, _private: Private) -> InnerMapMut<'static> {
self.inner
}
}
#[doc(hidden)]

@ -446,6 +446,7 @@ rust_test(
],
deps = [
"//rust/test:map_unittest_cc_rust_proto",
"//rust/test:unittest_cc_rust_proto",
"@crate_index//:googletest",
],
)
@ -462,6 +463,7 @@ rust_test(
],
deps = [
"//rust/test:map_unittest_upb_rust_proto",
"//rust/test:unittest_upb_rust_proto",
"@crate_index//:googletest",
],
)

@ -6,9 +6,10 @@
// https://developers.google.com/open-source/licenses/bsd
use googletest::prelude::*;
use map_unittest_proto::TestMap;
use map_unittest_proto::{TestMap, TestMapWithMessages};
use paste::paste;
use std::collections::HashMap;
use unittest_proto::TestAllTypes;
macro_rules! generate_map_primitives_tests {
(
@ -145,3 +146,153 @@ fn test_bytes_and_string_copied() {
);
assert_that!(msg.map_int32_bytes_mut().get(1).unwrap(), eq(b"world"));
}
macro_rules! generate_map_with_msg_values_tests {
(
$(($k_field:ident, $k_nonzero:expr, $k_other:expr $(,)?)),*
$(,)?
) => {
paste! { $(
#[test]
fn [< test_map_ $k_field _all_types >]() {
// We need to cover the following upb/c++ thunks:
// TODO - b/323883851: Add test once Map::new is public.
// * new
// * free (covered implicitly by drop)
// * clear, size, insert, get, remove, iter, iter_next (all covered below)
let mut msg = TestMapWithMessages::new();
assert_that!(msg.[< map_ $k_field _all_types >]().len(), eq(0));
assert_that!(msg.[< map_ $k_field _all_types >]().get($k_nonzero), none());
// this block makes sure `insert` copies/moves, not borrows.
{
let mut msg_val = TestAllTypes::new();
msg_val.optional_int32_mut().set(1001);
assert_that!(
msg
.[< map_ $k_field _all_types_mut >]()
.insert($k_nonzero, msg_val.as_view()),
eq(true),
"`insert` should return true when key was inserted."
);
assert_that!(
msg
.[< map_ $k_field _all_types_mut >]()
.insert($k_nonzero, msg_val.as_view()),
eq(false),
"`insert` should return false when key was already present."
);
}
assert_that!(
msg.[< map_ $k_field _all_types >]().len(),
eq(1),
"`size` thunk should return correct len.");
assert_that!(
msg.[< map_ $k_field _all_types >]().get($k_nonzero),
some(anything()),
"`get` should return Some when key present.");
assert_that!(
msg.[< map_ $k_field _all_types >]().get($k_nonzero).unwrap().optional_int32(),
eq(1001));
assert_that!(
msg.[< map_ $k_field _all_types >]().get($k_other),
none(),
"`get` should return None when key missing.");
msg.[< map_ $k_field _all_types_mut >]().clear();
assert_that!(
msg.[< map_ $k_field _all_types >]().len(),
eq(0),
"`clear` should drop all elements.");
assert_that!(
msg.[< map_ $k_field _all_types_mut >]().insert($k_nonzero, TestAllTypes::new().as_view()),
eq(true));
assert_that!(
msg.[< map_ $k_field _all_types_mut >]().remove($k_nonzero),
eq(true),
"`remove` should return true when key was present.");
assert_that!(msg.[< map_ $k_field _all_types >]().len(), eq(0));
assert_that!(
msg.[< map_ $k_field _all_types_mut >]().remove($k_nonzero),
eq(false),
"`remove` should return false when key was missing.");
// empty iter
// assert_that!(
// msg.[< map_ $k_field _all_types_mut >]().iter().collect::<Vec<_>>(),
// elements_are![],
// "`iter` should work when empty."
// );
assert_that!(
msg.[< map_ $k_field _all_types_mut >]().keys().collect::<Vec<_>>(),
elements_are![],
"`iter` should work when empty."
);
assert_that!(
msg.[< map_ $k_field _all_types_mut >]().values().collect::<Vec<_>>(),
elements_are![],
"`iter` should work when empty."
);
// single element iter
assert_that!(
msg.[< map_ $k_field _all_types_mut >]().insert($k_nonzero, TestAllTypes::new().as_view()),
eq(true));
// assert_that!(
// msg.[< map_ $k_field _all_types >]().iter().collect::<Vec<_>>(),
// unordered_elements_are![
// eq(($k_nonzero, anything())),
// ]
// );
assert_that!(
msg.[< map_ $k_field _all_types >]().keys().collect::<Vec<_>>(),
unordered_elements_are![eq($k_nonzero)]
);
assert_that!(
msg.[< map_ $k_field _all_types >]().values().collect::<Vec<_>>().len(),
eq(1));
// 2 element iter
assert_that!(
msg
.[< map_ $k_field _all_types_mut >]()
.insert($k_other, TestAllTypes::new().as_view()),
eq(true));
assert_that!(
msg.[< map_ $k_field _all_types >]().iter().collect::<Vec<_>>().len(),
eq(2)
);
assert_that!(
msg.[< map_ $k_field _all_types >]().keys().collect::<Vec<_>>(),
unordered_elements_are![eq($k_nonzero), eq($k_other)]
);
assert_that!(
msg.[< map_ $k_field _all_types >]().values().collect::<Vec<_>>().len(),
eq(2)
);
}
)* }
}
}
generate_map_with_msg_values_tests!(
(int32, 1i32, 2i32),
(int64, 1i64, 2i64),
(uint32, 1u32, 2u32),
(uint64, 1u64, 2u64),
(sint32, 1, 2),
(sint64, 1, 2),
(fixed32, 1u32, 2u32),
(fixed64, 1u64, 2u64),
(sfixed32, 1, 2),
(sfixed64, 1, 2),
// TODO - b/324468833: fix msan failure
// (bool, true, false),
(string, "foo", "bar"),
);

@ -46,10 +46,10 @@ pub struct Arena {
extern "C" {
// `Option<NonNull<T: Sized>>` is ABI-compatible with `*mut T`
fn upb_Arena_New() -> Option<RawArena>;
fn upb_Arena_Free(arena: RawArena);
fn upb_Arena_Malloc(arena: RawArena, size: usize) -> *mut u8;
fn upb_Arena_Realloc(arena: RawArena, ptr: *mut u8, old: usize, new: usize) -> *mut u8;
pub fn upb_Arena_New() -> Option<RawArena>;
pub fn upb_Arena_Free(arena: RawArena);
pub fn upb_Arena_Malloc(arena: RawArena, size: usize) -> *mut u8;
pub fn upb_Arena_Realloc(arena: RawArena, ptr: *mut u8, old: usize, new: usize) -> *mut u8;
}
impl Arena {
@ -716,13 +716,24 @@ pub struct InnerMapMut<'msg> {
_phantom: PhantomData<&'msg Arena>,
}
#[doc(hidden)]
impl<'msg> InnerMapMut<'msg> {
pub fn new(_private: Private, raw: RawMap, raw_arena: RawArena) -> Self {
InnerMapMut { raw, raw_arena, _phantom: PhantomData }
}
#[doc(hidden)]
pub fn as_raw(&self, _private: Private) -> RawMap {
self.raw
}
#[doc(hidden)]
pub fn raw_arena(&self, _private: Private) -> RawArena {
self.raw_arena
}
}
trait UpbTypeConversions: Proxied {
pub trait UpbTypeConversions: Proxied {
fn upb_type() -> UpbCType;
fn to_message_value(val: View<'_, Self>) -> upb_MessageValue;
fn empty_message_value() -> upb_MessageValue;
@ -858,7 +869,7 @@ impl RawMapIter {
/// # Safety
/// - `self.map` must be valid, and remain valid while the return value is
/// in use.
pub(crate) unsafe fn next_unchecked(
pub unsafe fn next_unchecked(
&mut self,
_private: Private,
) -> Option<(upb_MessageValue, upb_MessageValue)> {
@ -986,7 +997,7 @@ impl_ProxiedInMapValue_for_key_types!(i32, u32, i64, u64, bool, ProtoStr);
#[repr(C)]
#[allow(dead_code)]
enum upb_MapInsertStatus {
pub enum upb_MapInsertStatus {
Inserted = 0,
Replaced = 1,
OutOfMemory = 2,
@ -1019,25 +1030,25 @@ pub unsafe fn upb_Map_InsertAndReturnIfInserted(
}
extern "C" {
fn upb_Map_New(arena: RawArena, key_type: UpbCType, value_type: UpbCType) -> RawMap;
fn upb_Map_Size(map: RawMap) -> usize;
fn upb_Map_Insert(
pub fn upb_Map_New(arena: RawArena, key_type: UpbCType, value_type: UpbCType) -> RawMap;
pub fn upb_Map_Size(map: RawMap) -> usize;
pub fn upb_Map_Insert(
map: RawMap,
key: upb_MessageValue,
value: upb_MessageValue,
arena: RawArena,
) -> upb_MapInsertStatus;
fn upb_Map_Get(map: RawMap, key: upb_MessageValue, value: *mut upb_MessageValue) -> bool;
fn upb_Map_Delete(
pub fn upb_Map_Get(map: RawMap, key: upb_MessageValue, value: *mut upb_MessageValue) -> bool;
pub fn upb_Map_Delete(
map: RawMap,
key: upb_MessageValue,
removed_value: *mut upb_MessageValue,
) -> bool;
fn upb_Map_Clear(map: RawMap);
pub fn upb_Map_Clear(map: RawMap);
static __rust_proto_kUpb_Map_Begin: usize;
fn upb_Map_Next(
pub fn upb_Map_Next(
map: RawMap,
key: *mut upb_MessageValue,
value: *mut upb_MessageValue,

@ -36,10 +36,8 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
auto value_type = field.message_type()->map_value()->type();
switch (value_type) {
case FieldDescriptor::TYPE_ENUM:
case FieldDescriptor::TYPE_MESSAGE:
return std::make_unique<UnsupportedField>(
"Maps with values of type enum and message are not "
"supported");
"Maps with values of type enum are not supported");
default:
return std::make_unique<Map>();
}

@ -5,6 +5,8 @@
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
#include <string>
#include "google/protobuf/compiler/cpp/helpers.h"
#include "google/protobuf/compiler/rust/accessors/accessor_case.h"
#include "google/protobuf/compiler/rust/accessors/accessor_generator.h"
@ -115,13 +117,24 @@ void Map::InExternC(Context& ctx, const FieldDescriptor& field) const {
)rs");
}
std::string MapElementTypeName(FieldDescriptor::CppType cpp_type,
const Descriptor* message_type) {
if (cpp_type == FieldDescriptor::CPPTYPE_MESSAGE ||
cpp_type == FieldDescriptor::CPPTYPE_ENUM) {
return cpp::QualifiedClassName(message_type);
}
return cpp::PrimitiveTypeName(cpp_type);
}
void Map::InThunkCc(Context& ctx, const FieldDescriptor& field) const {
ctx.Emit(
{{"field", cpp::FieldName(&field)},
{"Key",
cpp::PrimitiveTypeName(field.message_type()->map_key()->cpp_type())},
MapElementTypeName(field.message_type()->map_key()->cpp_type(),
field.message_type()->map_key()->message_type())},
{"Value",
cpp::PrimitiveTypeName(field.message_type()->map_value()->cpp_type())},
MapElementTypeName(field.message_type()->map_value()->cpp_type(),
field.message_type()->map_value()->message_type())},
{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"getter_thunk", ThunkName(ctx, field, "get")},
{"getter_mut_thunk", ThunkName(ctx, field, "get_mut")},

@ -406,6 +406,335 @@ void MessageProxiedInRepeated(Context& ctx, const Descriptor& msg) {
ABSL_LOG(FATAL) << "unreachable";
}
// (
// 1) identifier used in thunk name - keep in sync with kMapKeyCppTypes!,
// 2) Rust key typename (K in Map<K, V>, so e.g. `[u8]` for bytes),
// 3) whether previous needs $pb$ prefix,
// 4) key typename as used in function parameter (e.g. `PtrAndLen` for bytes),
// 5) whether previous needs $pbi$ prefix,
// 6) expression converting `key` variable to expected thunk param type,
// 7) fn expression converting thunk param type to map key type (e.g. from
// PtrAndLen to &[u8]).
// 8) whether previous needs $pb$ prefix,
//)
struct MapKeyType {
absl::string_view thunk_ident;
absl::string_view key_t;
bool needs_key_t_prefix;
absl::string_view param_key_t;
bool needs_param_key_t_prefix;
absl::string_view key_expr;
absl::string_view from_ffi_key_expr;
bool needs_from_ffi_key_prefix;
};
constexpr MapKeyType kMapKeyTypes[] = {
{"i32", "i32", false, "i32", false, "key", "k", false},
{"u32", "u32", false, "u32", false, "key", "k", false},
{"i64", "i64", false, "i64", false, "key", "k", false},
{"u64", "u64", false, "u64", false, "key", "k", false},
{"bool", "bool", false, "bool", false, "key", "k", false},
{"string", "ProtoStr", true, "PtrAndLen", true, "key.as_bytes().into()",
"ProtoStr::from_utf8_unchecked(k.as_ref())", true},
{"bytes", "[u8]", false, "PtrAndLen", true, "key.into()", "k.as_ref()",
false}};
void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
for (const auto& t : kMapKeyTypes) {
ctx.Emit(
{{"map_new_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "new")},
{"map_free_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "free")},
{"map_clear_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "clear")},
{"map_size_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "size")},
{"map_insert_thunk",
RawMapThunk(ctx, msg, t.thunk_ident, "insert")},
{"map_get_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "get")},
{"map_remove_thunk",
RawMapThunk(ctx, msg, t.thunk_ident, "remove")},
{"map_iter_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "iter")},
{"map_iter_next_thunk",
RawMapThunk(ctx, msg, t.thunk_ident, "iter_next")},
{"key_expr", t.key_expr},
io::Printer::Sub({"param_key_t",
[&] {
if (t.needs_param_key_t_prefix) {
ctx.Emit({{"param_key_t", t.param_key_t}},
"$pbi$::$param_key_t$");
} else {
ctx.Emit({{"param_key_t", t.param_key_t}},
"$param_key_t$");
}
}})
.WithSuffix(""),
io::Printer::Sub({"key_t",
[&] {
if (t.needs_key_t_prefix) {
ctx.Emit({{"key_t", t.key_t}},
"$pb$::$key_t$");
} else {
ctx.Emit({{"key_t", t.key_t}}, "$key_t$");
}
}})
.WithSuffix(""),
io::Printer::Sub(
{"from_ffi_key_expr",
[&] {
if (t.needs_from_ffi_key_prefix) {
ctx.Emit({{"from_ffi_key_expr", t.from_ffi_key_expr}},
"$pb$::$from_ffi_key_expr$");
} else {
ctx.Emit({{"from_ffi_key_expr", t.from_ffi_key_expr}},
"$from_ffi_key_expr$");
}
}})
.WithSuffix("")},
R"rs(
extern "C" {
fn $map_new_thunk$() -> $pbi$::RawMap;
fn $map_free_thunk$(m: $pbi$::RawMap);
fn $map_clear_thunk$(m: $pbi$::RawMap);
fn $map_size_thunk$(m: $pbi$::RawMap) -> usize;
fn $map_insert_thunk$(m: $pbi$::RawMap, key: $param_key_t$, value: $pbi$::RawMessage) -> bool;
fn $map_get_thunk$(m: $pbi$::RawMap, key: $param_key_t$, value: *mut $pbi$::RawMessage) -> bool;
fn $map_remove_thunk$(m: $pbi$::RawMap, key: $param_key_t$, value: *mut $pbi$::RawMessage) -> bool;
fn $map_iter_thunk$(m: $pbi$::RawMap) -> $pbr$::UntypedMapIterator;
fn $map_iter_next_thunk$(iter: &mut $pbr$::UntypedMapIterator, key: *mut $param_key_t$, value: *mut $pbi$::RawMessage);
}
impl $pb$::ProxiedInMapValue<$key_t$> for $Msg$ {
fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> {
unsafe {
$pb$::Map::from_inner(
$pbi$::Private,
$pbr$::InnerMapMut::new($pbi$::Private, $map_new_thunk$())
)
}
}
unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) {
unsafe { $map_free_thunk$(map.as_raw($pbi$::Private)); }
}
fn map_clear(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>) {
unsafe { $map_clear_thunk$(map.as_raw($pbi$::Private)); }
}
fn map_len(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> usize {
unsafe { $map_size_thunk$(map.as_raw($pbi$::Private)) }
}
fn map_insert(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>, value: $pb$::View<'_, Self>) -> bool {
unsafe { $map_insert_thunk$(map.as_raw($pbi$::Private), $key_expr$, value.raw_msg()) }
}
fn map_get<'a>(map: $pb$::View<'a, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> Option<$pb$::View<'a, Self>> {
let key = $key_expr$;
let mut value = $std$::mem::MaybeUninit::uninit();
let found = unsafe { $map_get_thunk$(map.as_raw($pbi$::Private), key, value.as_mut_ptr()) };
if !found {
return None;
}
Some($Msg$View::new($pbi$::Private, unsafe { value.assume_init() }))
}
fn map_remove(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> bool {
let mut value = $std$::mem::MaybeUninit::uninit();
unsafe { $map_remove_thunk$(map.as_raw($pbi$::Private), $key_expr$, value.as_mut_ptr()) }
}
fn map_iter(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> $pb$::MapIter<'_, $key_t$, Self> {
// SAFETY:
// - The backing map for `map.as_raw` is valid for at least '_.
// - A View that is live for '_ guarantees the backing map is unmodified for '_.
// - The `iter` function produces an iterator that is valid for the key
// and value types, and live for at least '_.
unsafe {
$pb$::MapIter::from_raw(
$pbi$::Private,
$map_iter_thunk$(map.as_raw($pbi$::Private))
)
}
}
fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> Option<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> {
// SAFETY:
// - The `MapIter` API forbids the backing map from being mutated for 'a,
// and guarantees that it's the correct key and value types.
// - The thunk is safe to call as long as the iterator isn't at the end.
// - The thunk always writes to key and value fields and does not read.
// - The thunk does not increment the iterator.
unsafe {
iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>(
$pbi$::Private,
$map_iter_next_thunk$,
|k| $from_ffi_key_expr$,
|raw_msg| $Msg$View::new($pbi$::Private, raw_msg)
)
}
}
}
)rs");
}
return;
case Kernel::kUpb:
ctx.Emit(
{
{"minitable", UpbMinitableName(msg)},
{"new_thunk", ThunkName(ctx, msg, "new")},
},
R"rs(
impl $pbr$::UpbTypeConversions for $Msg$ {
fn upb_type() -> $pbr$::UpbCType {
$pbr$::UpbCType::Message
}
fn to_message_value(
val: $pb$::View<'_, Self>) -> $pbr$::upb_MessageValue {
$pbr$::upb_MessageValue { msg_val: Some(val.raw_msg()) }
}
fn empty_message_value() -> $pbr$::upb_MessageValue {
Self::to_message_value(
$Msg$View::new(
$pbi$::Private,
$pbr$::ScratchSpace::zeroed_block($pbi$::Private)))
}
unsafe fn to_message_value_copy_if_required(
arena: $pbi$::RawArena,
val: $pb$::View<'_, Self>) -> $pbr$::upb_MessageValue {
// Self::to_message_value(val)
// SAFETY: The arena memory is not freed due to `ManuallyDrop`.
let cloned_msg = $pbr$::upb_Message_DeepClone(
val.raw_msg(), $std$::ptr::addr_of!($minitable$), arena)
.expect("upb_Message_DeepClone failed.");
Self::to_message_value(
$Msg$View::new($pbi$::Private, cloned_msg))
}
unsafe fn from_message_value<'msg>(msg: $pbr$::upb_MessageValue)
-> $pb$::View<'msg, Self> {
$Msg$View::new(
$pbi$::Private,
unsafe { msg.msg_val }
.expect("expected present message value in map"))
}
}
)rs");
for (const auto& t : kMapKeyTypes) {
ctx.Emit({io::Printer::Sub({"key_t",
[&] {
if (t.needs_key_t_prefix) {
ctx.Emit({{"key_t", t.key_t}},
"$pb$::$key_t$");
} else {
ctx.Emit({{"key_t", t.key_t}},
"$key_t$");
}
}})
.WithSuffix("")},
R"rs(
impl $pb$::ProxiedInMapValue<$key_t$> for $Msg$ {
fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> {
let arena = $pbr$::Arena::new();
let raw_arena = arena.raw();
std::mem::forget(arena);
unsafe {
$pb$::Map::from_inner(
$pbi$::Private,
$pbr$::InnerMapMut::new(
$pbi$::Private,
$pbr$::upb_Map_New(
raw_arena,
<$key_t$ as $pbr$::UpbTypeConversions>::upb_type(),
<Self as $pbr$::UpbTypeConversions>::upb_type()),
raw_arena))
}
}
unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) {
// SAFETY:
// - `map.raw_arena($pbi$::Private)` is a live `upb_Arena*`
// - This function is only called once for `map` in `Drop`.
unsafe {
$pbr$::upb_Arena_Free(map.inner($pbi$::Private).raw_arena($pbi$::Private));
}
}
fn map_clear(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>) {
unsafe {
$pbr$::upb_Map_Clear(map.as_raw($pbi$::Private));
}
}
fn map_len(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> usize {
unsafe {
$pbr$::upb_Map_Size(map.as_raw($pbi$::Private))
}
}
fn map_insert(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>, value: $pb$::View<'_, Self>) -> bool {
let arena = map.inner($pbi$::Private).raw_arena($pbi$::Private);
unsafe {
$pbr$::upb_Map_InsertAndReturnIfInserted(
map.as_raw($pbi$::Private),
<$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key),
<Self as $pbr$::UpbTypeConversions>::to_message_value_copy_if_required(arena, value),
arena
)
}
}
fn map_get<'a>(map: $pb$::View<'a, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> Option<$pb$::View<'a, Self>> {
let mut val = <Self as $pbr$::UpbTypeConversions>::empty_message_value();
let found = unsafe {
$pbr$::upb_Map_Get(
map.as_raw($pbi$::Private),
<$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key),
&mut val)
};
if !found {
return None;
}
Some(unsafe { <Self as $pbr$::UpbTypeConversions>::from_message_value(val) })
}
fn map_remove(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> bool {
let mut val = <Self as $pbr$::UpbTypeConversions>::empty_message_value();
unsafe {
$pbr$::upb_Map_Delete(
map.as_raw($pbi$::Private),
<$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key),
&mut val)
}
}
fn map_iter(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> $pb$::MapIter<'_, $key_t$, Self> {
// SAFETY: View<Map<'_,..>> guarantees its RawMap outlives '_.
unsafe {
$pb$::MapIter::from_raw($pbi$::Private, $pbr$::RawMapIter::new($pbi$::Private, map.as_raw($pbi$::Private)))
}
}
fn map_iter_next<'a>(
iter: &mut $pb$::MapIter<'a, $key_t$, Self>
) -> Option<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> {
// SAFETY: MapIter<'a, ..> guarantees its RawMapIter outlives 'a.
unsafe { iter.as_raw_mut($pbi$::Private).next_unchecked($pbi$::Private) }
// SAFETY: MapIter<K, V> returns key and values message values
// with the variants for K and V active.
.map(|(k, v)| unsafe {(
<$key_t$ as $pbr$::UpbTypeConversions>::from_message_value(k),
<Self as $pbr$::UpbTypeConversions>::from_message_value(v),
)})
}
}
)rs");
}
}
}
} // namespace
void GenerateRs(Context& ctx, const Descriptor& msg) {
@ -522,6 +851,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
{"settable_impl_for_view",
[&] { MessageSettableValueForView(ctx, msg); }},
{"repeated_impl", [&] { MessageProxiedInRepeated(ctx, msg); }},
{"map_value_impl", [&] { MessageProxiedInMapValue(ctx, msg); }},
{"unwrap_upb",
[&] {
if (ctx.is_upb()) {
@ -567,7 +897,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
impl<'msg> $Msg$View<'msg> {
#[doc(hidden)]
pub fn new(_private: $pbi$::Private, msg: $pbi$::RawMessage) -> Self {
Self { msg, _phantom: std::marker::PhantomData }
Self { msg, _phantom: $std$::marker::PhantomData }
}
fn raw_msg(&self) -> $pbi$::RawMessage {
@ -668,6 +998,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
}
$repeated_impl$
$map_value_impl$
#[derive(Debug)]
#[allow(dead_code)]
@ -724,10 +1055,10 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
impl<'msg> $pb$::ViewProxy<'msg> for $Msg$Mut<'msg> {
type Proxied = $Msg$;
fn as_view(&self) -> $pb$::View<'_, $Msg$> {
$Msg$View { msg: self.raw_msg(), _phantom: std::marker::PhantomData }
$Msg$View { msg: self.raw_msg(), _phantom: $std$::marker::PhantomData }
}
fn into_view<'shorter>(self) -> $pb$::View<'shorter, $Msg$> where 'msg: 'shorter {
$Msg$View { msg: self.raw_msg(), _phantom: std::marker::PhantomData }
$Msg$View { msg: self.raw_msg(), _phantom: $std$::marker::PhantomData }
}
}
@ -893,6 +1224,108 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
$nested_msg_thunks$
)cc");
// (
// 1) identifier used in thunk name - keep in sync with kMapKeyTypes!,
// 2) c+ key typename (K in Map<K, V>, so e.g. `std::string` for bytes),
// 3) key typename as used in function parameter (e.g. `PtrAndLen` for
// bytes),
// 4) expression converting `key` variable to expected C++ type,
// 5) expression convering `key` variable from the C++ type back into the
// FFI type.
//)
struct MapKeyCppTypes {
absl::string_view thunk_ident;
absl::string_view key_t;
absl::string_view param_key_t;
absl::string_view key_expr;
absl::string_view to_ffi_key_expr;
};
constexpr MapKeyCppTypes kMapKeyCppTypes[] = {
{"i32", "int32_t", "int32_t", "key", "cpp_key"},
{"u32", "uint32_t", "uint32_t", "key", "cpp_key"},
{"i64", "int64_t", "int64_t", "key", "cpp_key"},
{"u64", "uint64_t", "uint64_t", "key", "cpp_key"},
{"bool", "bool", "bool", "key", "cpp_key"},
{"string", "std::string", "google::protobuf::rust_internal::PtrAndLen",
"std::string(key.ptr, key.len)",
"google::protobuf::rust_internal::PtrAndLen(cpp_key.data(), cpp_key.size())"},
{"bytes", "std::string", "google::protobuf::rust_internal::PtrAndLen",
"std::string(key.ptr, key.len)",
"google::protobuf::rust_internal::PtrAndLen(cpp_key.data(), cpp_key.size())"}};
for (const auto& t : kMapKeyCppTypes) {
ctx.Emit(
{
{"map_new_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "new")},
{"map_free_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "free")},
{"map_clear_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "clear")},
{"map_size_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "size")},
{"map_insert_thunk",
RawMapThunk(ctx, msg, t.thunk_ident, "insert")},
{"map_get_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "get")},
{"map_remove_thunk",
RawMapThunk(ctx, msg, t.thunk_ident, "remove")},
{"map_iter_thunk", RawMapThunk(ctx, msg, t.thunk_ident, "iter")},
{"map_iter_next_thunk",
RawMapThunk(ctx, msg, t.thunk_ident, "iter_next")},
{"key_t", t.key_t},
{"param_key_t", t.param_key_t},
{"key_expr", t.key_expr},
{"to_ffi_key_expr", t.to_ffi_key_expr},
{"pkg::Msg", cpp::QualifiedClassName(&msg)},
{"abi", "\"C\""}, // Workaround for syntax highlight bug in VSCode.
},
R"cc(
extern $abi$ {
const google::protobuf::Map<$key_t$, $pkg::Msg$>* $map_new_thunk$() {
return new google::protobuf::Map<$key_t$, $pkg::Msg$>();
}
void $map_free_thunk$(const google::protobuf::Map<$key_t$, $pkg::Msg$>* m) { delete m; }
void $map_clear_thunk$(google::protobuf::Map<$key_t$, $pkg::Msg$> * m) { m->clear(); }
size_t $map_size_thunk$(const google::protobuf::Map<$key_t$, $pkg::Msg$>* m) {
return m->size();
}
bool $map_insert_thunk$(google::protobuf::Map<$key_t$, $pkg::Msg$> * m,
$param_key_t$ key, $pkg::Msg$ value) {
auto k = $key_expr$;
auto it = m->find(k);
if (it != m->end()) {
return false;
}
(*m)[k] = value;
return true;
}
bool $map_get_thunk$(const google::protobuf::Map<$key_t$, $pkg::Msg$>* m,
$param_key_t$ key, const $pkg::Msg$** value) {
auto it = m->find($key_expr$);
if (it == m->end()) {
return false;
}
const $pkg::Msg$* cpp_value = &it->second;
*value = cpp_value;
return true;
}
bool $map_remove_thunk$(google::protobuf::Map<$key_t$, $pkg::Msg$> * m,
$param_key_t$ key, $pkg::Msg$ * value) {
auto num_removed = m->erase($key_expr$);
return num_removed > 0;
}
google::protobuf::internal::UntypedMapIterator $map_iter_thunk$(
const google::protobuf::Map<$key_t$, $pkg::Msg$>* m) {
return google::protobuf::internal::UntypedMapIterator::FromTyped(m->cbegin());
}
void $map_iter_next_thunk$(
const google::protobuf::internal::UntypedMapIterator* iter,
$param_key_t$* key, const $pkg::Msg$** value) {
auto typed_iter = iter->ToTyped<
google::protobuf::Map<$key_t$, $pkg::Msg$>::const_iterator>();
const auto& cpp_key = typed_iter->first;
const auto& cpp_value = typed_iter->second;
*key = $to_ffi_key_expr$;
*value = &cpp_value;
}
}
)cc");
}
}
} // namespace rust

@ -67,6 +67,12 @@ std::string GetHeaderFile(Context& ctx, const FileDescriptor& file) {
return absl::StrCat(basename, ".proto.h");
}
std::string RawMapThunk(Context& ctx, const Descriptor& msg,
absl::string_view key_t, absl::string_view op) {
return absl::StrCat("__rust_proto_thunk__", key_t, "_",
GetUnderscoreDelimitedFullName(ctx, *&msg), "_", op);
}
namespace {
template <typename T>

@ -33,6 +33,8 @@ std::string ThunkName(Context& ctx, const OneofDescriptor& field,
std::string ThunkName(Context& ctx, const Descriptor& msg,
absl::string_view op);
std::string RawMapThunk(Context& ctx, const Descriptor& msg,
absl::string_view key_t, absl::string_view op);
// Returns the local constant that defines the vtable for mutating `field`.
std::string VTableName(const FieldDescriptor& field);

@ -39,6 +39,21 @@ message TestMap {
map<int32, TestAllTypes> map_int32_all_types = 19;
}
message TestMapWithMessages {
map<int32, TestAllTypes> map_int32_all_types = 1;
map<int64, TestAllTypes> map_int64_all_types = 2;
map<uint32, TestAllTypes> map_uint32_all_types = 3;
map<uint64, TestAllTypes> map_uint64_all_types = 4;
map<sint32, TestAllTypes> map_sint32_all_types = 5;
map<sint64, TestAllTypes> map_sint64_all_types = 6;
map<fixed32, TestAllTypes> map_fixed32_all_types = 7;
map<fixed64, TestAllTypes> map_fixed64_all_types = 8;
map<sfixed32, TestAllTypes> map_sfixed32_all_types = 9;
map<sfixed64, TestAllTypes> map_sfixed64_all_types = 10;
map<bool, TestAllTypes> map_bool_all_types = 11;
map<string, TestAllTypes> map_string_all_types = 12;
}
message TestMapSubmessage {
TestMap test_map = 1;
}

Loading…
Cancel
Save