diff --git a/rust/test/shared/accessors_proto3_test.rs b/rust/test/shared/accessors_proto3_test.rs index 3eb6e5df61..8800a097a9 100644 --- a/rust/test/shared/accessors_proto3_test.rs +++ b/rust/test/shared/accessors_proto3_test.rs @@ -205,3 +205,37 @@ fn test_oneof_accessors() { assert_that!(msg.oneof_field(), matches_pattern!(OneofBytes(eq(b"123")))); } + +#[test] +fn test_oneof_mut_accessors() { + use TestAllTypes_::OneofFieldMut::*; + + let mut msg = TestAllTypes::new(); + assert_that!(msg.oneof_field_mut(), matches_pattern!(not_set(_))); + + msg.oneof_uint32_set(Some(7)); + + match msg.oneof_field_mut() { + OneofUint32(mut v) => { + assert_eq!(v.get(), 7); + v.set(8); + assert_eq!(v.get(), 8); + } + f => panic!("unexpected field_mut type! {:?}", f), + } + + // Confirm that the mut write above applies to both the field accessor and the + // oneof view accessor. + assert_that!(msg.oneof_uint32_opt(), eq(Optional::Set(8))); + assert_that!( + msg.oneof_field(), + matches_pattern!(TestAllTypes_::OneofField::OneofUint32(eq(8))) + ); + + msg.oneof_uint32_set(None); + assert_that!(msg.oneof_field_mut(), matches_pattern!(not_set(_))); + + msg.oneof_uint32_set(Some(7)); + msg.oneof_bytes_mut().set(b"123"); + assert_that!(msg.oneof_field_mut(), matches_pattern!(OneofBytes(_))); +} diff --git a/rust/test/shared/accessors_test.rs b/rust/test/shared/accessors_test.rs index 2b207ac11c..e5949fffab 100644 --- a/rust/test/shared/accessors_test.rs +++ b/rust/test/shared/accessors_test.rs @@ -696,6 +696,40 @@ fn test_oneof_accessors() { assert_that!(msg.oneof_field(), matches_pattern!(OneofBytes(eq(b"123")))); } +#[test] +fn test_oneof_mut_accessors() { + use TestAllTypes_::OneofFieldMut::*; + + let mut msg = TestAllTypes::new(); + assert_that!(msg.oneof_field_mut(), matches_pattern!(not_set(_))); + + msg.oneof_uint32_set(Some(7)); + + match msg.oneof_field_mut() { + OneofUint32(mut v) => { + assert_eq!(v.get(), 7); + v.set(8); + assert_eq!(v.get(), 8); + } + f => panic!("unexpected field_mut type! {:?}", f), + } + + // Confirm that the mut write above applies to both the field accessor and the + // oneof view accessor. + assert_that!(msg.oneof_uint32_opt(), eq(Optional::Set(8))); + assert_that!( + msg.oneof_field(), + matches_pattern!(TestAllTypes_::OneofField::OneofUint32(eq(8))) + ); + + msg.oneof_uint32_set(None); + assert_that!(msg.oneof_field_mut(), matches_pattern!(not_set(_))); + + msg.oneof_uint32_set(Some(7)); + msg.oneof_bytes_mut().set(b"123"); + assert_that!(msg.oneof_field_mut(), matches_pattern!(OneofBytes(_))); +} + macro_rules! generate_repeated_numeric_test { ($(($t: ty, $field: ident)),*) => { paste! { $( diff --git a/src/google/protobuf/compiler/rust/oneof.cc b/src/google/protobuf/compiler/rust/oneof.cc index 33836a02f9..7f2671d7ea 100644 --- a/src/google/protobuf/compiler/rust/oneof.cc +++ b/src/google/protobuf/compiler/rust/oneof.cc @@ -246,15 +246,26 @@ void GenerateOneofAccessors(Context oneof) { if (rs_type.empty()) { continue; } - // TODO: Allow mut. - /*oneof.Emit({ + oneof.Emit( + { {"case", ToCamelCase(field->name())}, - {"rs_getter", field->name() + "_mut"}, + {"rs_mut_getter", field->name() + "_mut"}, {"type", rs_type}, }, + + // The flow here is: + // 1) First find out which oneof field is already set (if any) + // 2) If a field is set, call the corresponding field's _mut() + // and wrap that Mut<> in the SomeOneofMut eum. + // During step 2 this code uses try_into_mut().unwrap() instead + // of .or_default() so that it will panic if step 1 says that + // the field is set, but then the _mut() accessor for the + // corresponding field shows as unset; if that happened it would + // imply a severe error in protobuf code; .or_default() would + // silently continue and cause the field to become set on the + // message, which is not the intended behavior. R"rs($Msg$_::$case_enum_name$::$case$ => - $Msg$_::$mut_enum_name$::$case$(self.$rs_getter$()), )rs"); - */ + $Msg$_::$mut_enum_name$::$case$(self.$rs_mut_getter$().try_into_mut().unwrap()), )rs"); } }}, {"case_thunk", Thunk(oneof, "case")}},