diff --git a/rust/test/shared/accessors_proto3_test.rs b/rust/test/shared/accessors_proto3_test.rs index 3ab69d4157..c60a7c564c 100644 --- a/rust/test/shared/accessors_proto3_test.rs +++ b/rust/test/shared/accessors_proto3_test.rs @@ -307,74 +307,18 @@ fn test_oneof_accessors_view_long_lifetime() { fn test_oneof_enum_accessors() { use unittest_proto3::{ TestOneof2, - TestOneof2_::{Foo, NestedEnum}, + TestOneof2_::{Foo, FooCase, NestedEnum}, }; let mut msg = TestOneof2::new(); assert_that!(msg.foo_enum_opt(), eq(Optional::Unset(NestedEnum::Unknown))); assert_that!(msg.foo(), matches_pattern!(Foo::not_set(_))); + assert_that!(msg.foo_case(), matches_pattern!(FooCase::not_set)); msg.set_foo_enum(NestedEnum::Bar); assert_that!(msg.foo_enum_opt(), eq(Optional::Set(NestedEnum::Bar))); assert_that!(msg.foo(), matches_pattern!(Foo::FooEnum(eq(NestedEnum::Bar)))); -} - -#[test] -fn test_oneof_mut_accessors() { - use TestAllTypes_::OneofFieldMut::*; - - let mut msg = TestAllTypes::new(); - assert_that!(msg.oneof_field_mut(), matches_pattern!(not_set(_))); - - msg.set_oneof_uint32(7); - - match msg.oneof_field_mut() { - OneofUint32(mut v) => { - assert_that!(v.get(), eq(7)); - v.set(8); - assert_that!(v.get(), eq(8)); - } - f => panic!("unexpected field_mut type! {:?}", f), - } - - // Confirm that the mut write above applies to both the field accessor and the - // oneof view accessor. - assert_that!(msg.oneof_uint32_opt(), eq(Optional::Set(8))); - assert_that!( - msg.oneof_field(), - matches_pattern!(TestAllTypes_::OneofField::OneofUint32(eq(8))) - ); - - // Clearing a different field in the same oneof doesn't affect the other, set - // field. - msg.clear_oneof_bytes(); - assert_that!( - msg.oneof_field(), - matches_pattern!(TestAllTypes_::OneofField::OneofUint32(eq(8))) - ); - - msg.clear_oneof_uint32(); - assert_that!(msg.oneof_field_mut(), matches_pattern!(not_set(_))); - - msg.set_oneof_uint32(7); - msg.set_oneof_bytes(b"123"); - assert_that!(msg.oneof_field_mut(), matches_pattern!(OneofBytes(_))); -} - -#[test] -fn test_oneof_mut_enum_accessors() { - use unittest_proto3::{ - TestOneof2, - TestOneof2_::{FooMut, NestedEnum}, - }; - - let mut msg = TestOneof2::new(); - assert_that!(msg.foo_enum_opt(), eq(Optional::Unset(NestedEnum::Unknown))); - assert_that!(msg.foo_mut(), matches_pattern!(FooMut::not_set(_))); - - msg.set_foo_enum(NestedEnum::Bar); - assert_that!(msg.foo_enum_opt(), eq(Optional::Set(NestedEnum::Bar))); - assert_that!(msg.foo_mut(), matches_pattern!(FooMut::FooEnum(_))); + assert_that!(msg.foo_case(), matches_pattern!(FooCase::FooEnum)); } #[test] diff --git a/rust/test/shared/accessors_test.rs b/rust/test/shared/accessors_test.rs index b9a5565b8b..5970d05453 100644 --- a/rust/test/shared/accessors_test.rs +++ b/rust/test/shared/accessors_test.rs @@ -924,27 +924,32 @@ fn test_default_import_enum_accessors() { #[test] fn test_oneof_accessors() { use unittest_proto::TestOneof2; - use unittest_proto::TestOneof2_::{Foo::*, NestedEnum}; + use unittest_proto::TestOneof2_::{Foo::*, FooCase, NestedEnum}; let mut msg = TestOneof2::new(); assert_that!(msg.foo(), matches_pattern!(not_set(_))); + assert_that!(msg.foo_case(), eq(FooCase::not_set)); msg.foo_int_mut().set(7); assert_that!(msg.foo_int_opt(), eq(Optional::Set(7))); assert_that!(msg.foo(), matches_pattern!(FooInt(eq(7)))); + assert_that!(msg.foo_case(), eq(FooCase::FooInt)); msg.foo_int_mut().clear(); assert_that!(msg.foo_int_opt(), eq(Optional::Unset(0))); assert_that!(msg.foo(), matches_pattern!(not_set(_))); + assert_that!(msg.foo_case(), eq(FooCase::not_set)); msg.foo_int_mut().set(7); msg.foo_bytes_mut().set(b"123"); assert_that!(msg.foo_int_opt(), eq(Optional::Unset(0))); assert_that!(msg.foo(), matches_pattern!(FooBytes(eq(b"123")))); + assert_that!(msg.foo_case(), eq(FooCase::FooBytes)); msg.foo_enum_mut().set(NestedEnum::Foo); assert_that!(msg.foo(), matches_pattern!(FooEnum(eq(NestedEnum::Foo)))); + assert_that!(msg.foo_case(), eq(FooCase::FooEnum)); // Test the accessors or $Msg$Mut let mut msg_mut = msg.as_mut(); @@ -952,68 +957,21 @@ fn test_oneof_accessors() { msg_mut.foo_int_mut().set(7); msg_mut.foo_bytes_mut().set(b"123"); assert_that!(msg_mut.foo(), matches_pattern!(FooBytes(eq(b"123")))); + assert_that!(msg_mut.foo_case(), eq(FooCase::FooBytes)); assert_that!(msg_mut.foo_int_opt(), eq(Optional::Unset(0))); // Test the accessors on $Msg$View let msg_view = msg.as_view(); assert_that!(msg_view.foo(), matches_pattern!(FooBytes(eq(b"123")))); + assert_that!(msg_view.foo_case(), eq(FooCase::FooBytes)); assert_that!(msg_view.foo_int_opt(), eq(Optional::Unset(0))); // TODO: Add tests covering a message-type field in a oneof. } -#[test] -fn test_oneof_mut_accessors() { - use unittest_proto::TestOneof2; - use unittest_proto::TestOneof2_::{Foo, FooMut::*, NestedEnum}; - - let mut msg = TestOneof2::new(); - assert_that!(msg.foo_mut(), matches_pattern!(not_set(_))); - - msg.foo_int_mut().set(7); - - match msg.foo_mut() { - FooInt(mut v) => { - assert_that!(v.get(), eq(7)); - v.set(8); - assert_that!(v.get(), eq(8)); - } - f => panic!("unexpected field_mut type! {:?}", f), - } - - // Confirm that the mut write above applies to both the field accessor and the - // oneof view accessor. - assert_that!(msg.foo_int_opt(), eq(Optional::Set(8))); - assert_that!(msg.foo(), matches_pattern!(Foo::FooInt(_))); - - msg.foo_int_mut().clear(); - assert_that!(msg.foo_mut(), matches_pattern!(not_set(_))); - - msg.foo_int_mut().set(7); - msg.foo_bytes_mut().set(b"123"); - assert_that!(msg.foo_mut(), matches_pattern!(FooBytes(_))); - - msg.foo_enum_mut().set(NestedEnum::Baz); - assert_that!(msg.foo_mut(), matches_pattern!(FooEnum(_))); - - // Test the mut accessors or $Msg$Mut - let mut msg_mut = msg.as_mut(); - match msg_mut.foo_mut() { - FooEnum(mut v) => { - assert_that!(v.get(), eq(NestedEnum::Baz)); - v.set(NestedEnum::Bar); - assert_that!(v.get(), eq(NestedEnum::Bar)); - } - f => panic!("unexpected field_mut type! {:?}", f), - } - assert_that!(msg.foo_enum(), eq(NestedEnum::Bar)); - - // TODO: Add tests covering a message-type field in a oneof. -} - #[test] fn test_msg_oneof_default_accessors() { - use unittest_proto::TestOneof2_::{Bar::*, NestedEnum}; + use unittest_proto::TestOneof2_::{Bar::*, BarCase, NestedEnum}; let mut msg = unittest_proto::TestOneof2::new(); assert_that!(msg.bar(), matches_pattern!(not_set(_))); @@ -1021,78 +979,28 @@ fn test_msg_oneof_default_accessors() { msg.bar_int_mut().set(7); assert_that!(msg.bar_int_opt(), eq(Optional::Set(7))); assert_that!(msg.bar(), matches_pattern!(BarInt(eq(7)))); + assert_that!(msg.bar_case(), eq(BarCase::BarInt)); msg.bar_int_mut().clear(); assert_that!(msg.bar_int_opt(), eq(Optional::Unset(5))); assert_that!(msg.bar(), matches_pattern!(not_set(_))); + assert_that!(msg.bar_case(), eq(BarCase::not_set)); msg.bar_int_mut().set(7); msg.bar_bytes_mut().set(b"123"); assert_that!(msg.bar_int_opt(), eq(Optional::Unset(5))); assert_that!(msg.bar_enum_opt(), eq(Optional::Unset(NestedEnum::Bar))); assert_that!(msg.bar(), matches_pattern!(BarBytes(eq(b"123")))); + assert_that!(msg.bar_case(), eq(BarCase::BarBytes)); msg.bar_enum_mut().set(NestedEnum::Baz); assert_that!(msg.bar(), matches_pattern!(BarEnum(eq(NestedEnum::Baz)))); + assert_that!(msg.bar_case(), eq(BarCase::BarEnum)); assert_that!(msg.bar_int_opt(), eq(Optional::Unset(5))); // TODO: Add tests covering a message-type field in a oneof. } -#[test] -fn test_oneof_default_mut_accessors() { - use unittest_proto::TestOneof2_::{Bar, BarMut, BarMut::*, NestedEnum}; - - let mut msg = unittest_proto::TestOneof2::new(); - assert_that!(msg.bar_mut(), matches_pattern!(not_set(_))); - - msg.bar_int_mut().set(7); - - match msg.bar_mut() { - BarInt(mut v) => { - assert_that!(v.get(), eq(7)); - v.set(8); - assert_that!(v.get(), eq(8)); - } - f => panic!("unexpected field_mut type! {:?}", f), - } - - // Confirm that the mut write above applies to all three of: - // - The field accessor - // - The oneof mut accessor - // - The oneof view accessor - // And then each of the applicable cases on: - // - The owned msg directly - // - The msg as a $Msg$Mut - // - The msg as a $Msg$View - assert_that!(msg.bar_int_opt(), eq(Optional::Set(8))); - assert_that!(msg.bar_mut(), matches_pattern!(BarMut::BarInt(_))); - assert_that!(msg.bar(), matches_pattern!(Bar::BarInt(_))); - - let mut msg_mut = msg.as_mut(); - assert_that!(msg_mut.bar_int_opt(), eq(Optional::Set(8))); - assert_that!(msg_mut.bar_mut(), matches_pattern!(BarMut::BarInt(_))); - assert_that!(msg_mut.bar(), matches_pattern!(Bar::BarInt(_))); - - let msg_view = msg.as_view(); - assert_that!(msg_view.bar_int_opt(), eq(Optional::Set(8))); - // This test correctly fails to compile if this line is uncommented: - // assert_that!(msg_view.bar_mut(), matches_pattern!(BarMut::BarInt(_))); - assert_that!(msg_view.bar(), matches_pattern!(Bar::BarInt(_))); - - msg.bar_int_mut().clear(); - assert_that!(msg.bar_mut(), matches_pattern!(not_set(_))); - - msg.bar_int_mut().set(7); - msg.bar_bytes_mut().set(b"123"); - assert_that!(msg.bar_mut(), matches_pattern!(BarBytes(_))); - - msg.bar_enum_mut().set(NestedEnum::Baz); - assert_that!(msg.bar_mut(), matches_pattern!(BarEnum(_))); - - // TODO: Add tests covering a message-type field in a oneof. -} - #[test] fn test_set_message_from_view() { use protobuf::MutProxy; diff --git a/src/google/protobuf/compiler/rust/oneof.cc b/src/google/protobuf/compiler/rust/oneof.cc index 380fcfe175..f559ccd4d8 100644 --- a/src/google/protobuf/compiler/rust/oneof.cc +++ b/src/google/protobuf/compiler/rust/oneof.cc @@ -24,17 +24,12 @@ namespace protobuf { namespace compiler { namespace rust { -// We emit three Rust enums: +// For each oneof we emit two Rust enums with corresponding accessors: // - An enum acting as a tagged union that has each case holds a View<> of // each of the cases. Named as the one_of name in CamelCase. -// - An enum acting as a tagged union that has each case holds a Mut<> of -// each of the cases. Named as one_of name in CamelCase with "Mut" appended. -// [TODO: Mut not implemented yet]. -// - A simple enum whose cases have int values matching the cpp or upb's -// case enum. Named as the one_of camelcase with "Case" appended. -// All three contain cases matching the fields in the oneof CamelCased. -// The first and second are exposed in the API, the third is internal and -// used for interop with the Kernels in the generation of the other two. +// - A simple 'which oneof field is set' enum which directly maps to the +// underlying enum used for the 'cases' accessor in C++ or upb. Named as the +// one_of camelcase with "Case" appended. // // Example: // For this oneof: @@ -47,28 +42,30 @@ namespace rust { // // This will emit as the exposed API: // pub mod SomeMsg_ { -// // The 'view' struct (no suffix on the name) // pub enum SomeOneof<'msg> { // FieldA(i32) = 7, // FieldB(View<'msg, SomeMsg>) = 9, // not_set(std::marker::PhantomData<&'msg ()>) = 0 // } -// pub enum SomeOneofMut<'msg> { -// FieldA(Mut<'msg, i32>) = 7, -// FieldB(Mut<'msg, SomeMsg>) = 9, -// not_set(std::marker::PhantomData<&'msg ()>) = 0 +// +// #[repr(C)] +// pub enum SomeOneofCase { +// FieldA = 7, +// FieldB = 9, +// not_set = 0 // } // } // impl SomeMsg { // pub fn some_oneof(&self) -> SomeOneof {...} -// pub fn some_oneof_mut(&mut self) -> SomeOneofMut {...} +// pub fn some_oneof_case(&self) -> SomeOneofCase {...} // } // impl SomeMsgMut { // pub fn some_oneof(&self) -> SomeOneof {...} -// pub fn some_oneof_mut(&mut self) -> SomeOneofMut {...} +// pub fn some_oneof_case(&self) -> SomeOneofCase {...} // } // impl SomeMsgView { -// pub fn some_oneof(&self) -> SomeOneof {...} +// pub fn some_oneof(self) -> SomeOneof {...} +// pub fn some_oneof_case(self) -> SomeOneofCase {...} // } // // An additional "Case" enum which just reflects the corresponding slot numbers @@ -110,71 +107,28 @@ std::string RsTypeNameView(Context& ctx, const FieldDescriptor& field) { return ""; } -// A user-friendly rust type for a mutator of this field with lifetime 'msg. -std::string RsTypeNameMut(Context& ctx, const FieldDescriptor& field) { - if (field.options().has_ctype()) { - return ""; // TODO: b/308792377 - ctype fields not supported yet. - } - switch (GetRustFieldType(field)) { - case RustFieldType::INT32: - case RustFieldType::INT64: - case RustFieldType::UINT32: - case RustFieldType::UINT64: - case RustFieldType::FLOAT: - case RustFieldType::DOUBLE: - case RustFieldType::BOOL: - return absl::StrCat("::__pb::PrimitiveMut<'msg, ", RsTypePath(ctx, field), - ">"); - case RustFieldType::BYTES: - return "::__pb::BytesMut<'msg>"; - case RustFieldType::STRING: - return "::__pb::ProtoStrMut<'msg>"; - case RustFieldType::MESSAGE: - return absl::StrCat("::__pb::Mut<'msg, ", RsTypePath(ctx, field), ">"); - case RustFieldType::ENUM: - return absl::StrCat("::__pb::Mut<'msg, ", RsTypePath(ctx, field), ">"); - } - - ABSL_LOG(FATAL) << "Unexpected field type: " << field.type_name(); - return ""; -} - } // namespace void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) { ctx.Emit( - {{"view_enum_name", OneofViewEnumRsName(oneof)}, - {"mut_enum_name", OneofMutEnumRsName(oneof)}, - {"view_fields", - [&] { - for (int i = 0; i < oneof.field_count(); ++i) { - auto& field = *oneof.field(i); - std::string rs_type = RsTypeNameView(ctx, field); - if (rs_type.empty()) { - continue; - } - ctx.Emit({{"name", OneofCaseRsName(field)}, - {"type", rs_type}, - {"number", std::to_string(field.number())}}, - R"rs($name$($type$) = $number$, - )rs"); - } - }}, - {"mut_fields", - [&] { - for (int i = 0; i < oneof.field_count(); ++i) { - auto& field = *oneof.field(i); - std::string rs_type = RsTypeNameMut(ctx, field); - if (rs_type.empty()) { - continue; - } - ctx.Emit({{"name", OneofCaseRsName(field)}, - {"type", rs_type}, - {"number", std::to_string(field.number())}}, - R"rs($name$($type$) = $number$, + { + {"view_enum_name", OneofViewEnumRsName(oneof)}, + {"view_fields", + [&] { + for (int i = 0; i < oneof.field_count(); ++i) { + auto& field = *oneof.field(i); + std::string rs_type = RsTypeNameView(ctx, field); + if (rs_type.empty()) { + continue; + } + ctx.Emit({{"name", OneofCaseRsName(field)}, + {"type", rs_type}, + {"number", std::to_string(field.number())}}, + R"rs($name$($type$) = $number$, )rs"); - } - }}}, + } + }}, + }, // TODO: Revisit if isize is the optimal repr for this enum. // TODO: not_set currently has phantom data just to avoid the // lifetime on the enum breaking compilation if there are zero supported @@ -190,18 +144,6 @@ void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) { #[allow(non_camel_case_types)] not_set(std::marker::PhantomData<&'msg ()>) = 0 } - - #[non_exhaustive] - #[derive(Debug)] - #[allow(dead_code)] - #[repr(isize)] - pub enum $mut_enum_name$<'msg> { - $mut_fields$ - - #[allow(non_camel_case_types)] - not_set(std::marker::PhantomData<&'msg ()>) = 0 - } - )rs"); // Note: This enum is used as the Thunk return type for getting which case is @@ -221,7 +163,7 @@ void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) { #[repr(C)] #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[allow(dead_code)] - pub(super) enum $case_enum_name$ { + pub enum $case_enum_name$ { $cases$ #[allow(non_camel_case_types)] @@ -234,92 +176,45 @@ void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) { void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof, AccessorCase accessor_case) { ctx.Emit( - {{"oneof_name", RsSafeName(oneof.name())}, - {"view_lifetime", ViewLifetime(accessor_case)}, - {"view_self", ViewReceiver(accessor_case)}, - {"view_enum_name", OneofViewEnumRsName(oneof)}, - {"mut_enum_name", OneofMutEnumRsName(oneof)}, - {"case_enum_name", OneofCaseEnumRsName(oneof)}, - {"view_cases", - [&] { - for (int i = 0; i < oneof.field_count(); ++i) { - auto& field = *oneof.field(i); - std::string rs_type = RsTypeNameView(ctx, field); - if (rs_type.empty()) { - continue; - } - ctx.Emit( - { - {"case", OneofCaseRsName(field)}, - {"rs_getter", RsSafeName(field.name())}, - {"type", rs_type}, - }, - R"rs( + { + {"oneof_name", RsSafeName(oneof.name())}, + {"view_lifetime", ViewLifetime(accessor_case)}, + {"self", ViewReceiver(accessor_case)}, + {"view_enum_name", OneofViewEnumRsName(oneof)}, + {"case_enum_name", OneofCaseEnumRsName(oneof)}, + {"view_cases", + [&] { + for (int i = 0; i < oneof.field_count(); ++i) { + auto& field = *oneof.field(i); + std::string rs_type = RsTypeNameView(ctx, field); + if (rs_type.empty()) { + continue; + } + ctx.Emit( + { + {"case", OneofCaseRsName(field)}, + {"rs_getter", RsSafeName(field.name())}, + {"type", rs_type}, + }, + R"rs( $Msg$_::$case_enum_name$::$case$ => $Msg$_::$view_enum_name$::$case$(self.$rs_getter$()), )rs"); - } - }}, - {"mut_cases", - [&] { - for (int i = 0; i < oneof.field_count(); ++i) { - auto& field = *oneof.field(i); - std::string rs_type = RsTypeNameMut(ctx, field); - if (rs_type.empty()) { - continue; - } - ctx.Emit( - {{"case", OneofCaseRsName(field)}, - {"rs_mut_getter", field.name() + "_mut"}, - {"type", rs_type}}, - // Any extra behavior needed to map the mut getter into the - // unwrapped Mut<>. Right now Message's _mut already returns - // the Mut directly, but for scalars the accessor will return - // an Optional which we then grab the mut by doing - // .try_into_mut().unwrap(). - // - // Note that this unwrap() is safe because the flow is: - // 1) Find out which oneof field is already set (if any) - // 2) If a field is set, call the corresponding field's _mut() - // and wrap the result in the SomeOneofMut enum. - // The unwrap() will only ever panic if the which oneof enum - // disagrees with the corresponding field presence which. - R"rs( - $Msg$_::$case_enum_name$::$case$ => - $Msg$_::$mut_enum_name$::$case$( - self.$rs_mut_getter$().try_into_mut().unwrap()), - )rs"); - } - }}, - {"case_thunk", ThunkName(ctx, oneof, "case")}, - {"getter", - [&] { - ctx.Emit({}, R"rs( - pub fn $oneof_name$($view_self$) -> $Msg$_::$view_enum_name$<$view_lifetime$> { - match unsafe { $case_thunk$(self.raw_msg()) } { - $view_cases$ - _ => $Msg$_::$view_enum_name$::not_set(std::marker::PhantomData) - } - } - )rs"); - }}, - {"getter_mut", - [&] { - if (accessor_case == AccessorCase::VIEW) { - return; - } - ctx.Emit({}, R"rs( - pub fn $oneof_name$_mut(&mut self) -> $Msg$_::$mut_enum_name$ { - match unsafe { $case_thunk$(self.raw_msg()) } { - $mut_cases$ - _ => $Msg$_::$mut_enum_name$::not_set(std::marker::PhantomData) + } + }}, + {"case_thunk", ThunkName(ctx, oneof, "case")}, + }, + R"rs( + pub fn $oneof_name$($self$) -> $Msg$_::$view_enum_name$<$view_lifetime$> { + match $self$.$oneof_name$_case() { + $view_cases$ + _ => $Msg$_::$view_enum_name$::not_set(std::marker::PhantomData) } } - )rs"); - }}}, - R"rs( - $getter$ - $getter_mut$ + + pub fn $oneof_name$_case($self$) -> $Msg$_::$case_enum_name$ { + unsafe { $case_thunk$(self.raw_msg()) } + } )rs"); }