diff --git a/rust/proto_macro.rs b/rust/proto_macro.rs index a70afffa62..b3f937d605 100644 --- a/rust/proto_macro.rs +++ b/rust/proto_macro.rs @@ -40,6 +40,8 @@ macro_rules! proto { #[macro_export(local_inner_macros)] #[doc(hidden)] macro_rules! proto_internal { + // @merge rules are used to find a trailing ..expr on the message and call merge_from on it + // before the fields of the message are set. (@merge $msg:ident $ident:ident : $expr:expr, $($rest:tt)*) => { proto_internal!(@merge $msg $($rest)*); }; @@ -51,52 +53,224 @@ macro_rules! proto_internal { $msg.merge_from($expr); }; - // nested message, - (@msg $msg:ident $submsg:ident : $($msgtype:ident)::+ { $field:ident : $($value:tt)* }, $($rest:tt)*) => { - proto_internal!(@msg $msg $submsg : $($msgtype)::+ { $field : $($value)* }); + // @msg rules are used to set the fields of the message. There is a lot of duplication here + // because we need to parse the message type using a :: separated list of identifiers. + // nested message and trailing fields + (@msg $msg:ident $submsg:ident : $($msgtype:ident)::+ { $($value:tt)* }, $($rest:tt)*) => { + proto_internal!(@msg $msg $submsg : $($msgtype)::+ { $($value)* }); proto_internal!(@msg $msg $($rest)*); }; - (@msg $msg:ident $submsg:ident : ::$($msgtype:ident)::+ { $field:ident : $($value:tt)* }, $($rest:tt)*) => { - proto_internal!(@msg $msg $submsg : ::$($msgtype)::+ { $field : $($value)* }); + // nested message with leading :: on type and trailing fields + (@msg $msg:ident $submsg:ident : ::$($msgtype:ident)::+ { $($value:tt)* }, $($rest:tt)*) => { + proto_internal!(@msg $msg $submsg : ::$($msgtype)::+ { $($value)* }); proto_internal!(@msg $msg $($rest)*); }; + // nested message using __ + (@msg $msg:ident $submsg:ident : __ { $($value:tt)* }) => { + { + let mut $msg = $crate::__internal::paste!($msg.[<$submsg _mut>]()); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + } + }; // nested message - (@msg $msg:ident $submsg:ident : $($msgtype:ident)::+ { $field:ident : $($value:tt)* }) => { + (@msg $msg:ident $submsg:ident : $($msgtype:ident)::+ { $($value:tt)* }) => { { let mut $msg: <$($msgtype)::+ as $crate::MutProxied>::Mut<'_> = $crate::__internal::paste!($msg.[<$submsg _mut>]()); - proto_internal!(@merge $msg $field : $($value)*); - proto_internal!(@msg $msg $field : $($value)*); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); } }; - (@msg $msg:ident $submsg:ident : ::$($msgtype:ident)::+ { $field:ident : $($value:tt)* }) => { + // nested message with leading :: + (@msg $msg:ident $submsg:ident : ::$($msgtype:ident)::+ { $($value:tt)* }) => { { let mut $msg: <::$($msgtype)::+ as $crate::MutProxied>::Mut<'_> = $crate::__internal::paste!($msg.[<$submsg _mut>]()); - proto_internal!(@merge $msg $field : $($value)*); - proto_internal!(@msg $msg $field : $($value)*); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); } }; - // empty nested message, - (@msg $msg:ident $submsg:ident : $($msgtype:ident)::+ { }, $($rest:tt)*) => { - proto_internal!(@msg $msg $submsg : $($msgtype)::+ { }); - proto_internal!(@msg $msg $($rest)*); - }; - (@msg $msg:ident $submsg:ident : ::$($msgtype:ident)::+ { }, $($rest:tt)*) => { - proto_internal!(@msg $msg $submsg : ::$($msgtype)::+ { }); + // field with array literal and trailing fields + (@msg $msg:ident $ident:ident : [$($elems:tt)*], $($rest:tt)*) => { + proto_internal!(@msg $msg $ident : [$($elems)*]); proto_internal!(@msg $msg $($rest)*); }; - - // empty nested message - (@msg $msg:ident $submsg:ident : $($msgtype:ident)::+ { }) => { + // field with array literal, calls out to @array to look for nested messages + (@msg $msg:ident $ident:ident : [$($elems:tt)*]) => { { - let mut $msg: <$($msgtype)::+ as $crate::MutProxied>::Mut<'_> = $crate::__internal::paste!($msg.[<$submsg _mut>]()); + let _repeated = $crate::__internal::paste!($msg.[<$ident>]()); + let elems = proto_internal!(@array $msg _repeated [] $($elems)*); + $crate::__internal::paste!($msg.[](elems.into_iter())); } }; - (@msg $msg:ident $submsg:ident : ::$($msgtype:ident)::+ { }) => { - { - let mut $msg: <::$($msgtype)::+ as $crate::MutProxied>::Mut<'_> = $crate::__internal::paste!($msg.[<$submsg _mut>]()); - } + + // @array searches through an array literal for nested messages. + // If a message is found then we recursively call the macro on it to set the fields. + // This will create an array literal of owned messages to be used while setting the field. + // For primitive types they should just be passed through as an $expr. + // The array literal is constructed recursively, so the [] case has to be handled separately so + // that we can properly insert commas. This leads to a lot of duplication. + + // Message nested in array literal with trailing array items + (@array $msg:ident $repeated:ident [$($vals:expr),+] __ { $($value:tt)* }, $($rest:tt)*) => { + proto_internal!(@array $msg $repeated [ + $($vals),+ , + { + let mut $msg = $crate::get_repeated_default_value($crate::__internal::Private, $repeated); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] $($rest)*) + }; + + // Message nested in [] literal + (@array $msg:ident $repeated:ident [$($vals:expr),+] __ { $($value:tt)* }) => { + [ + $($vals),+ , + { + let mut $msg = $crate::get_repeated_default_value($crate::__internal::Private, $repeated); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] + }; + + // Message nested in array literal with trailing array items + (@array $msg:ident $repeated:ident [] __ { $($value:tt)* }, $($rest:tt)*) => { + proto_internal!(@array $msg $repeated [ + { + let mut $msg = $crate::get_repeated_default_value($crate::__internal::Private, $repeated); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] $($rest)*) + }; + + // Message nested in [] literal + (@array $msg:ident $repeated:ident [] __ { $($value:tt)* }) => { + [ + { + let mut $msg = $crate::get_repeated_default_value($crate::__internal::Private, $repeated); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] + }; + + // End of __ repeated, now we need to handle named types + + // Message nested in array literal with trailing array items + (@array $msg:ident $repeated:ident [$($vals:expr),+] $($msgtype:ident)::+ { $($value:tt)* }, $($rest:tt)*) => { + proto_internal!(@array $msg $repeated [ + $($vals),+ , + { + let mut $msg = $($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] $($rest)*) + }; + // Message nested in [] literal with leading :: on type and trailing array items + (@array $msg:ident $repeated:ident [$($vals:expr),+] ::$($msgtype:ident)::+ { $($value:tt)* }, $($rest:tt)*) => { + proto_internal!(@array $msg $repeated [ + $($vals),+ , + { + let mut $msg = ::$($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] $($rest)*) + }; + // Message nested in [] literal + (@array $msg:ident $repeated:ident [$($vals:expr),+] $($msgtype:ident)::+ { $($value:tt)* }) => { + [ + $($vals),+ , + { + let mut $msg = $($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] + }; + // Message nested in [] literal with leading :: on type + (@array $msg:ident $repeated:ident [$($vals:expr),+] ::$($msgtype:ident)::+ { $($value:tt)* }) => { + [ + $($vals),+ , + { + let mut $msg = ::$($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] + }; + + // Message nested in array literal with trailing array items + (@array $msg:ident $repeated:ident [] $($msgtype:ident)::+ { $($value:tt)* }, $($rest:tt)*) => { + proto_internal!(@array $msg $repeated [ + { + let mut $msg = $($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] $($rest)*) + }; + // with leading :: + (@array $msg:ident $repeated:ident [] ::$($msgtype:ident)::+ { $($value:tt)* }, $($rest:tt)*) => { + proto_internal!(@array $msg $repeated [ + { + let mut $msg = ::$($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] $($rest)*) + }; + // Message nested in [] literal + (@array $msg:ident $repeated:ident [] $($msgtype:ident)::+ { $($value:tt)* }) => { + [ + { + let mut $msg = $($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] + }; + (@array $msg:ident $repeated:ident [] ::$($msgtype:ident)::+ { $($value:tt)* }) => { + [ + { + let mut $msg = ::$($msgtype)::+::new(); + proto_internal!(@merge $msg $($value)*); + proto_internal!(@msg $msg $($value)*); + $msg + } + ] + }; + + (@array $msg:ident $repeated:ident [$($vals:expr),+] $expr:expr, $($rest:tt)*) => { + proto_internal!(@array $msg $repeated [$($vals),+, $expr] $($rest)*) + }; + (@array $msg:ident $repeated:ident [$($vals:expr),+] $expr:expr) => { + [$($vals),+, $expr] + }; + (@array $msg:ident $repeated:ident [] $expr:expr, $($rest:tt)*) => { + proto_internal!(@array $msg $repeated [$expr] $($rest)*) + }; + (@array $msg:ident $repeated:ident [] $expr:expr) => { + [$expr] + }; + (@array $msg:ident $repeated:ident []) => { + [] }; // field: expr, diff --git a/rust/shared.rs b/rust/shared.rs index 403fcb0ea5..4494ee7019 100644 --- a/rust/shared.rs +++ b/rust/shared.rs @@ -91,3 +91,10 @@ impl fmt::Display for SerializeError { write!(f, "Couldn't serialize proto into bytes (depth too deep or missing required fields)") } } + +pub fn get_repeated_default_value( + _: __internal::Private, + _: repeated::RepeatedView<'_, T>, +) -> T { + Default::default() +} diff --git a/rust/test/shared/proto_macro_test.rs b/rust/test/shared/proto_macro_test.rs index 07c6f6bb80..876a5a47f4 100644 --- a/rust/test/shared/proto_macro_test.rs +++ b/rust/test/shared/proto_macro_test.rs @@ -94,6 +94,14 @@ fn single_nested_message() { }); assert_that!(msg.optional_nested_message().bb(), eq(42)); + // field above and below it + let msg = proto!(TestAllTypes { + optional_int32: 1, + optional_nested_message: __ { bb: 42 }, + optional_int64: 2 + }); + assert_that!(msg.optional_nested_message().bb(), eq(42)); + // test empty initializer let msg = proto!(TestAllTypes {}); assert_that!(msg.has_optional_nested_message(), eq(false)); @@ -104,6 +112,14 @@ fn single_nested_message() { optional_nested_message: unittest_rust_proto::test_all_types::NestedMessage {} }); assert_that!(msg.has_optional_nested_message(), eq(true)); + + let msg = proto!(::unittest_rust_proto::TestAllTypes { + optional_nested_message: ::unittest_rust_proto::test_all_types::NestedMessage {} + }); + assert_that!(msg.has_optional_nested_message(), eq(true)); + + let msg = proto!(::unittest_rust_proto::TestAllTypes { optional_nested_message: __ {} }); + assert_that!(msg.has_optional_nested_message(), eq(true)); } #[test] @@ -151,3 +167,38 @@ fn test_spread_nested_msg() { assert_that!(msg2.child().child().payload().optional_int32(), eq(42)); assert_that!(msg2.child().child().child().payload().optional_int32(), eq(43)); } + +#[test] +fn test_repeated_i32() { + let msg = proto!(TestAllTypes { repeated_int32: [1, 1 + 1, 3] }); + assert_that!(msg.repeated_int32().len(), eq(3)); + assert_that!(msg.repeated_int32().get(0).unwrap(), eq(1)); + assert_that!(msg.repeated_int32().get(1).unwrap(), eq(2)); + assert_that!(msg.repeated_int32().get(2).unwrap(), eq(3)); +} + +#[test] +fn test_repeated_msg() { + let msg2 = proto!(NestedTestAllTypes { payload: TestAllTypes { optional_int32: 1 } }); + let msg = proto!(NestedTestAllTypes { + child: NestedTestAllTypes { + repeated_child: [ + NestedTestAllTypes { payload: TestAllTypes { optional_int32: 0 } }, + msg2, + __ { payload: TestAllTypes { optional_int32: 2 } } + ] + }, + repeated_child: [ + __ { payload: __ { optional_int32: 1 } }, + NestedTestAllTypes { payload: TestAllTypes { optional_int32: 2 } } + ] + }); + assert_that!(msg.child().repeated_child().len(), eq(3)); + assert_that!(msg.child().repeated_child().get(0).unwrap().payload().optional_int32(), eq(0)); + assert_that!(msg.child().repeated_child().get(1).unwrap().payload().optional_int32(), eq(1)); + assert_that!(msg.child().repeated_child().get(2).unwrap().payload().optional_int32(), eq(2)); + + assert_that!(msg.repeated_child().len(), eq(2)); + assert_that!(msg.repeated_child().get(0).unwrap().payload().optional_int32(), eq(1)); + assert_that!(msg.repeated_child().get(1).unwrap().payload().optional_int32(), eq(2)); +}