From 9bcd7087ae2fc560fec21ef83867394ca8409a76 Mon Sep 17 00:00:00 2001 From: Derek Benson Date: Tue, 1 Oct 2024 07:09:38 -0700 Subject: [PATCH] Add support for instantiating Maps using proto!. Also adds an IntoProxied implementation for Maps that accepts an Iterator of (K, V), just like what we have for Repeated fields. Migrates the helper functions used by the macro into __internal so that they are more hidden in Cargo. PiperOrigin-RevId: 681004963 --- rust/internal.rs | 19 +++ rust/map.rs | 15 +++ rust/proto_macro.rs | 191 ++++++++++++++++++++++++++- rust/shared.rs | 7 - rust/test/shared/BUILD | 2 + rust/test/shared/proto_macro_test.rs | 29 ++++ 6 files changed, 252 insertions(+), 11 deletions(-) diff --git a/rust/internal.rs b/rust/internal.rs index b8849373cf..f3a025e2fd 100644 --- a/rust/internal.rs +++ b/rust/internal.rs @@ -13,7 +13,10 @@ pub use paste::paste; pub use crate::r#enum::Enum; +use crate::map; +use crate::repeated; pub use crate::ProtoStr; +use crate::Proxied; pub use std::fmt::Debug; // TODO: Temporarily re-export these symbols which are now under @@ -37,3 +40,19 @@ pub trait SealedInternal {} pub trait MatcherEq: SealedInternal + Debug { fn matches(&self, o: &Self) -> bool; } + +/// Used by the proto! macro to get a default value for a repeated field. +pub fn get_repeated_default_value( + _: Private, + _: repeated::RepeatedView<'_, T>, +) -> T { + Default::default() +} + +/// Used by the proto! macro to get a default value for a map field. +pub fn get_map_default_value + Default>( + _: Private, + _: map::MapView<'_, K, V>, +) -> V { + Default::default() +} diff --git a/rust/map.rs b/rust/map.rs index 6dfa88236b..386b562144 100644 --- a/rust/map.rs +++ b/rust/map.rs @@ -400,6 +400,21 @@ where } } +impl<'msg, 'k, 'v, K, KView, V, VView, I> IntoProxied> for I +where + I: Iterator, + K: Proxied + 'msg + 'k, + V: ProxiedInMapValue + 'msg + 'v, + KView: Into>, + VView: IntoProxied, +{ + fn into_proxied(self, _private: Private) -> Map { + let mut m = Map::::new(); + m.as_mut().extend(self); + m + } +} + /// An iterator visiting all key-value pairs in arbitrary order. /// /// The iterator element type is `(View, View)`. diff --git a/rust/proto_macro.rs b/rust/proto_macro.rs index 19e1063dad..7eb17b9d29 100644 --- a/rust/proto_macro.rs +++ b/rust/proto_macro.rs @@ -117,7 +117,7 @@ macro_rules! proto_internal { proto_internal!(@array $msg $repeated [ $($vals),+ , { - let mut $msg = $crate::get_repeated_default_value($crate::__internal::Private, $repeated); + let mut $msg = $crate::__internal::get_repeated_default_value($crate::__internal::Private, $repeated); proto_internal!(@merge $msg $($value)*); proto_internal!(@msg $msg $($value)*); $msg @@ -130,7 +130,7 @@ macro_rules! proto_internal { [ $($vals),+ , { - let mut $msg = $crate::get_repeated_default_value($crate::__internal::Private, $repeated); + let mut $msg = $crate::__internal::get_repeated_default_value($crate::__internal::Private, $repeated); proto_internal!(@merge $msg $($value)*); proto_internal!(@msg $msg $($value)*); $msg @@ -142,7 +142,7 @@ macro_rules! proto_internal { (@array $msg:ident $repeated:ident [] __ { $($value:tt)* }, $($rest:tt)*) => { proto_internal!(@array $msg $repeated [ { - let mut $msg = $crate::get_repeated_default_value($crate::__internal::Private, $repeated); + let mut $msg = $crate::__internal::get_repeated_default_value($crate::__internal::Private, $repeated); proto_internal!(@merge $msg $($value)*); proto_internal!(@msg $msg $($value)*); $msg @@ -154,7 +154,7 @@ macro_rules! proto_internal { (@array $msg:ident $repeated:ident [] __ { $($value:tt)* }) => { [ { - let mut $msg = $crate::get_repeated_default_value($crate::__internal::Private, $repeated); + let mut $msg = $crate::__internal::get_repeated_default_value($crate::__internal::Private, $repeated); proto_internal!(@merge $msg $($value)*); proto_internal!(@msg $msg $($value)*); $msg @@ -257,6 +257,189 @@ macro_rules! proto_internal { ] }; + // Begin handling (key, value) for Maps in array literals + // Message nested in array literal with trailing array items + (@array $msg:ident $map:ident [$($vals:expr),+] ($key:expr, __ { $($value:tt)* }), $($rest:tt)*) => { + proto_internal!(@array $msg $map [ + $($vals),+ , + ( + $key, + { + let mut $msg = $crate::__internal::get_map_default_value($crate::__internal::Private, $map); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] $($rest)*) + }; + + // Message nested in [] literal + (@array $msg:ident $map:ident [$($vals:expr),+] ($key:expr, __ { $($value:tt)* })) => { + [ + $($vals),+ , + ( + $key, + { + let mut $msg = $crate::__internal::get_map_default_value($crate::__internal::Private, $map); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] + }; + + // Message nested in array literal with trailing array items + (@array $msg:ident $map:ident [] ($key:expr, __ { $($value:tt)* }), $($rest:tt)*) => { + proto_internal!(@array $msg $map [ + ( + $key, + { + let mut $msg = $crate::__internal::get_map_default_value($crate::__internal::Private, $map); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] $($rest)*) + }; + + // Message nested in [] literal + (@array $msg:ident $map:ident [] ($key:expr, __ { $($value:tt)* })) => { + [ + ( + $key, + { + let mut $msg = $crate::__internal::get_map_default_value($crate::__internal::Private, $map); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] + }; + + // End of __ repeated, now we need to handle named types + + // Message nested in array literal with trailing array items + (@array $msg:ident $map:ident [$($vals:expr),+] ($key:expr, $($msgtype:ident)::+ { $($value:tt)* }), $($rest:tt)*) => { + proto_internal!(@array $msg $map [ + $($vals),+ , + ( + $key, + { + let mut $msg = $($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] $($rest)*) + }; + // Message nested in [] literal with leading :: on type and trailing array items + (@array $msg:ident $map:ident [$($vals:expr),+] ($key:expr, ::$($msgtype:ident)::+ { $($value:tt)* }), $($rest:tt)*) => { + proto_internal!(@array $msg $map [ + $($vals),+ , + ( + $key, + { + let mut $msg = ::$($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] $($rest)*) + }; + // Message nested in [] literal + (@array $msg:ident $map:ident [$($vals:expr),+] ($key:expr, $($msgtype:ident)::+ { $($value:tt)* })) => { + [ + $($vals),+ , + ( + $key, + { + let mut $msg = $($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] + }; + // Message nested in [] literal with leading :: on type + (@array $msg:ident $map:ident [$($vals:expr),+] ($key:expr, ::$($msgtype:ident)::+ { $($value:tt)* })) => { + [ + $($vals),+ , + ( + $key, + { + let mut $msg = ::$($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] + }; + + // Message nested in array literal with trailing array items + (@array $msg:ident $map:ident [] ($key:expr, $($msgtype:ident)::+ { $($value:tt)* }), $($rest:tt)*) => { + proto_internal!(@array $msg $map [ + ( + $key, + { + let mut $msg = $($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] $($rest)*) + }; + // with leading :: + (@array $msg:ident $map:ident [] ($key:expr, ::$($msgtype:ident)::+ { $($value:tt)* }), $($rest:tt)*) => { + proto_internal!(@array $msg $map [ + ( + $key, + { + let mut $msg = ::$($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] $($rest)*) + }; + // Message nested in [] literal + (@array $msg:ident $map:ident [] ($key:expr, $($msgtype:ident)::+ { $($value:tt)* })) => { + [ + ( + $key, + { + let mut $msg = $($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] + }; + (@array $msg:ident $map:ident [] ::($key:expr, $($msgtype:ident)::+ { $($value:tt)* })) => { + [ + ( + $key, + { + let mut $msg = ::$($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ) + ] + }; + // End handling of (key, value) for Maps + (@array $msg:ident $repeated:ident [$($vals:expr),+] $expr:expr, $($rest:tt)*) => { proto_internal!(@array $msg $repeated [$($vals),+, $expr] $($rest)*) }; diff --git a/rust/shared.rs b/rust/shared.rs index 8e42de9650..fdca7db028 100644 --- a/rust/shared.rs +++ b/rust/shared.rs @@ -116,10 +116,3 @@ impl fmt::Display for SerializeError { write!(f, "Couldn't serialize proto into bytes (depth too deep or missing required fields)") } } - -pub fn get_repeated_default_value( - _: __internal::Private, - _: repeated::RepeatedView<'_, T>, -) -> T { - Default::default() -} diff --git a/rust/test/shared/BUILD b/rust/test/shared/BUILD index 43453f179f..faa4bf8d7c 100644 --- a/rust/test/shared/BUILD +++ b/rust/test/shared/BUILD @@ -456,6 +456,7 @@ rust_test( }, deps = [ "//rust:protobuf_cpp", + "//rust/test:map_unittest_cpp_rust_proto", "//src/google/protobuf:unittest_cpp_rust_proto", "@crate_index//:googletest", ], @@ -471,6 +472,7 @@ rust_test( deps = [ "//rust:protobuf_gtest_matchers_upb", "//rust:protobuf_upb", + "//rust/test:map_unittest_upb_rust_proto", "//src/google/protobuf:unittest_upb_rust_proto", "@crate_index//:googletest", ], diff --git a/rust/test/shared/proto_macro_test.rs b/rust/test/shared/proto_macro_test.rs index fa86f5dd7f..45ba49ddff 100644 --- a/rust/test/shared/proto_macro_test.rs +++ b/rust/test/shared/proto_macro_test.rs @@ -14,6 +14,8 @@ use unittest_rust_proto::{ NestedTestAllTypes, TestAllTypes, }; +use map_unittest_rust_proto::{TestMap, TestMapWithMessages}; + struct TestValue { val: i64, } @@ -202,3 +204,30 @@ fn test_repeated_msg() { assert_that!(msg.repeated_child().get(0).unwrap().payload().optional_int32(), eq(1)); assert_that!(msg.repeated_child().get(1).unwrap().payload().optional_int32(), eq(2)); } + +#[gtest] +fn test_string_maps() { + let msg = + proto!(TestMap { map_string_string: [("foo", "bar"), ("baz", "qux"), ("quux", "quuz")] }); + assert_that!(msg.map_string_string().len(), eq(3)); + assert_that!(msg.map_string_string().get("foo").unwrap(), eq("bar")); + assert_that!(msg.map_string_string().get("baz").unwrap(), eq("qux")); + assert_that!(msg.map_string_string().get("quux").unwrap(), eq("quuz")); +} + +#[gtest] +fn test_message_maps() { + let msg3 = proto!(TestAllTypes { optional_int32: 3 }); + let kv3 = ("quux", msg3); + let msg = proto!(TestMapWithMessages { + map_string_all_types: [ + ("foo", TestAllTypes { optional_int32: 1 }), + ("baz", __ { optional_int32: 2 }), + kv3 + ] + }); + assert_that!(msg.map_string_all_types().len(), eq(3)); + assert_that!(msg.map_string_all_types().get("foo").unwrap().optional_int32(), eq(1)); + assert_that!(msg.map_string_all_types().get("baz").unwrap().optional_int32(), eq(2)); + assert_that!(msg.map_string_all_types().get("quux").unwrap().optional_int32(), eq(3)); +}