diff --git a/rust/test/shared/simple_nested_test.rs b/rust/test/shared/simple_nested_test.rs index d40a7381de..298e47108d 100644 --- a/rust/test/shared/simple_nested_test.rs +++ b/rust/test/shared/simple_nested_test.rs @@ -33,25 +33,51 @@ fn test_nested_views() { #[test] fn test_nested_muts() { - // TODO: add actual mutation logic, this just peeks at InnerMut at - // the moment let mut outer_msg = Outer::new(); let inner_msg: InnerMut<'_> = outer_msg.inner_mut(); - assert_that!(inner_msg.double(), eq(0.0)); - assert_that!(inner_msg.float(), eq(0.0)); - assert_that!(inner_msg.int32(), eq(0)); - assert_that!(inner_msg.int64(), eq(0)); - assert_that!(inner_msg.uint32(), eq(0)); - assert_that!(inner_msg.uint64(), eq(0)); - assert_that!(inner_msg.sint32(), eq(0)); - assert_that!(inner_msg.sint64(), eq(0)); - assert_that!(inner_msg.fixed32(), eq(0)); - assert_that!(inner_msg.fixed64(), eq(0)); - assert_that!(inner_msg.sfixed32(), eq(0)); - assert_that!(inner_msg.sfixed64(), eq(0)); - assert_that!(inner_msg.bool(), eq(false)); - assert_that!(*inner_msg.string().as_bytes(), empty()); - assert_that!(*inner_msg.bytes(), empty()); + assert_that!( + inner_msg, + matches_pattern!(InnerMut{ + float(): eq(0.0), + double(): eq(0.0), + int32(): eq(0), + int64(): eq(0), + uint32(): eq(0), + uint64(): eq(0), + sint32(): eq(0), + sint64(): eq(0), + fixed32(): eq(0), + fixed64(): eq(0), + sfixed32(): eq(0), + sfixed64(): eq(0), + bool(): eq(false) + }) + ); + + inner_msg.double_mut().set(543.21); + assert_that!(inner_msg.double_mut().get(), eq(543.21)); + inner_msg.float_mut().set(1.23); + assert_that!(inner_msg.float_mut().get(), eq(1.23)); + inner_msg.int32_mut().set(12); + assert_that!(inner_msg.int32_mut().get(), eq(12)); + inner_msg.int64_mut().set(42); + assert_that!(inner_msg.int64_mut().get(), eq(42)); + inner_msg.uint32_mut().set(13); + assert_that!(inner_msg.uint32_mut().get(), eq(13)); + inner_msg.uint64_mut().set(5000); + assert_that!(inner_msg.uint64_mut().get(), eq(5000)); + inner_msg.sint32_mut().set(-2); + assert_that!(inner_msg.sint32_mut().get(), eq(-2)); + inner_msg.sint64_mut().set(322); + assert_that!(inner_msg.sint64_mut().get(), eq(322)); + inner_msg.fixed32_mut().set(77); + assert_that!(inner_msg.fixed32_mut().get(), eq(77)); + inner_msg.fixed64_mut().set(999); + assert_that!(inner_msg.fixed64_mut().get(), eq(999)); + inner_msg.bool_mut().set(true); + assert_that!(inner_msg.bool_mut().get(), eq(true)); + + // TODO: add mutation tests for strings and bytes } #[test] diff --git a/src/google/protobuf/compiler/rust/accessors/singular_message.cc b/src/google/protobuf/compiler/rust/accessors/singular_message.cc index cfa0d1fd0f..5985c10085 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_message.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_message.cc @@ -27,7 +27,7 @@ void SingularMessage::InMsgImpl(Context field) const { {"prefix", prefix}, {"field", field.desc().name()}, {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, + {"getter_mut_thunk", Thunk(field, "mutable")}, {"clearer_thunk", Thunk(field, "clear")}, { "view_body", @@ -59,18 +59,11 @@ void SingularMessage::InMsgImpl(Context field) const { [&] { if (field.is_upb()) { field.Emit({}, R"rs( - let submsg_opt = unsafe { $getter_thunk$(self.inner.msg) }; - match submsg_opt { - None => { - $prefix$Mut::new($pbi$::Private, - &mut self.inner, - $pbr$::ScratchSpace::zeroed_block($pbi$::Private)) - }, - Some(submsg) => { - $prefix$Mut::new($pbi$::Private, &mut self.inner, submsg) - } - } - )rs"); + let submsg = unsafe { + $getter_mut_thunk$(self.inner.msg, self.inner.arena.raw()) + }; + $prefix$Mut::new($pbi$::Private, &mut self.inner, submsg) + )rs"); } else { field.Emit({}, R"rs( let submsg = unsafe { $getter_mut_thunk$(self.inner.msg) }; @@ -98,8 +91,22 @@ void SingularMessage::InExternC(Context field) const { field.Emit( { {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, + {"getter_mut_thunk", Thunk(field, "mutable")}, {"clearer_thunk", Thunk(field, "clear")}, + {"getter_mut", + [&] { + if (field.is_cpp()) { + field.Emit( + R"rs( + fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage) + -> $pbi$::RawMessage;)rs"); + } else { + field.Emit( + R"rs(fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage, + arena: $pbi$::RawArena) + -> $pbi$::RawMessage;)rs"); + } + }}, {"ReturnType", [&] { if (field.is_cpp()) { @@ -114,7 +121,7 @@ void SingularMessage::InExternC(Context field) const { }, R"rs( fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $ReturnType$; - fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage) -> $ReturnType$; + $getter_mut$ fn $clearer_thunk$(raw_msg: $pbi$::RawMessage); )rs"); } @@ -123,7 +130,7 @@ void SingularMessage::InThunkCc(Context field) const { field.Emit({{"QualifiedMsg", cpp::QualifiedClassName(field.desc().containing_type())}, {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, + {"getter_mut_thunk", Thunk(field, "mutable")}, {"clearer_thunk", Thunk(field, "clear")}, {"field", cpp::FieldName(&field.desc())}}, R"cc( diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index d317fed9c4..6bc142fc24 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -168,10 +168,12 @@ void GetterForViewOrMut(Context field, bool is_mut) { auto fieldName = field.desc().name(); auto fieldType = field.desc().type(); auto getter_thunk = Thunk(field, "get"); + auto setter_thunk = Thunk(field, "set"); + auto clearer_thunk = Thunk(field, "clear"); // If we're dealing with a Mut, the getter must be supplied // self.inner.msg() whereas a View has to be supplied self.msg auto self = is_mut ? "self.inner.msg()" : "self.msg"; - auto returnType = PrimitiveRsTypeName(field.desc()); + auto rsType = PrimitiveRsTypeName(field.desc()); if (fieldType == FieldDescriptor::TYPE_STRING) { field.Emit( @@ -179,10 +181,10 @@ void GetterForViewOrMut(Context field, bool is_mut) { {"field", fieldName}, {"self", self}, {"getter_thunk", getter_thunk}, - {"ReturnType", returnType}, + {"RsType", rsType}, }, R"rs( - pub fn r#$field$(&self) -> &$ReturnType$ { + pub fn r#$field$(&self) -> &$RsType$ { let s = unsafe { $getter_thunk$($self$).as_ref() }; unsafe { __pb::ProtoStr::from_utf8_unchecked(s) } } @@ -193,22 +195,50 @@ void GetterForViewOrMut(Context field, bool is_mut) { {"field", fieldName}, {"self", self}, {"getter_thunk", getter_thunk}, - {"ReturnType", returnType}, + {"RsType", rsType}, }, R"rs( - pub fn r#$field$(&self) -> &$ReturnType$ { + pub fn r#$field$(&self) -> &$RsType$ { unsafe { $getter_thunk$($self$).as_ref() } } )rs"); } else { field.Emit({{"field", fieldName}, {"getter_thunk", getter_thunk}, + {"setter_thunk", setter_thunk}, + {"clearer_thunk", clearer_thunk}, {"self", self}, - {"ReturnType", returnType}}, + {"RsType", rsType}, + {"maybe_mutator", + [&] { + // TODO: once the rust public api is accessible, + // by tooling, ensure that this only appears for the + // mutational pathway + if (is_mut && fieldType) { + field.Emit({}, R"rs( + pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> { + static VTABLE: $pbi$::PrimitiveVTable<$RsType$> = + $pbi$::PrimitiveVTable::new( + $pbi$::Private, + $getter_thunk$, + $setter_thunk$); + $pb$::PrimitiveMut::from_singular( + $pbi$::Private, + unsafe { + $pbi$::RawVTableMutator::new($pbi$::Private, + self.inner, + &VTABLE) + }) + } + )rs"); + } + }}}, R"rs( - pub fn r#$field$(&self) -> $ReturnType$ { + pub fn r#$field$(&self) -> $RsType$ { unsafe { $getter_thunk$($self$) } } + + $maybe_mutator$ )rs"); } }