diff --git a/rust/cpp.rs b/rust/cpp.rs index 6a68bee691..6aefdbc04b 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -190,7 +190,7 @@ impl<'msg> MutatorMessageRef<'msg> { pub fn from_parent( _private: Private, - _parent_msg: &'msg mut MessageInner, + _parent_msg: MutatorMessageRef<'msg>, message_field_ptr: RawMessage, ) -> Self { MutatorMessageRef { msg: message_field_ptr, _phantom: PhantomData } diff --git a/rust/test/shared/simple_nested_test.rs b/rust/test/shared/simple_nested_test.rs index 4a26e58c33..134f7cc3dd 100644 --- a/rust/test/shared/simple_nested_test.rs +++ b/rust/test/shared/simple_nested_test.rs @@ -123,10 +123,26 @@ fn test_msg_from_outside() { } #[test] -fn test_recursive_msg() { +fn test_recursive_view() { let rec = nested_proto::nest::Recursive::new(); assert_that!(rec.num(), eq(0)); assert_that!(rec.rec().num(), eq(0)); assert_that!(rec.rec().rec().num(), eq(0)); // turtles all the way down... assert_that!(rec.rec().rec().rec().num(), eq(0)); // ... ad infinitum } + +#[test] +fn test_recursive_mut() { + let mut rec = nested_proto::nest::Recursive::new(); + let mut one = rec.rec_mut(); + let mut two = one.rec_mut(); + let mut three = two.rec_mut(); + let mut four = three.rec_mut(); + + four.num_mut().set(1); + assert_that!(four.num(), eq(1)); + + assert_that!(rec.num(), eq(0)); + assert_that!(rec.rec().rec().num(), eq(0)); + assert_that!(rec.rec().rec().rec().rec().num(), eq(1)); +} diff --git a/rust/upb.rs b/rust/upb.rs index 8041a23ad6..eaec859f4c 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -279,10 +279,10 @@ impl<'msg> MutatorMessageRef<'msg> { pub fn from_parent( _private: Private, - parent_msg: &'msg mut MessageInner, + parent_msg: MutatorMessageRef<'msg>, message_field_ptr: RawMessage, ) -> Self { - MutatorMessageRef { msg: message_field_ptr, arena: &parent_msg.arena } + MutatorMessageRef { msg: message_field_ptr, arena: parent_msg.arena } } pub fn msg(&self) -> RawMessage { diff --git a/src/google/protobuf/compiler/rust/BUILD.bazel b/src/google/protobuf/compiler/rust/BUILD.bazel index 13a86d29ef..03d7565191 100644 --- a/src/google/protobuf/compiler/rust/BUILD.bazel +++ b/src/google/protobuf/compiler/rust/BUILD.bazel @@ -110,6 +110,7 @@ cc_library( "accessors/unsupported_field.cc", ], hdrs = [ + "accessors/accessor_case.h", "accessors/accessor_generator.h", "accessors/accessors.h", ], diff --git a/src/google/protobuf/compiler/rust/accessors/accessor_case.h b/src/google/protobuf/compiler/rust/accessors/accessor_case.h new file mode 100644 index 0000000000..a68c55a6b9 --- /dev/null +++ b/src/google/protobuf/compiler/rust/accessors/accessor_case.h @@ -0,0 +1,21 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2023 Google LLC. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +#ifndef GOOGLE_PROTOBUF_COMPILER_RUST_ACCESSORS_ACCESSOR_CASE_H__ +#define GOOGLE_PROTOBUF_COMPILER_RUST_ACCESSORS_ACCESSOR_CASE_H__ + +// GenerateAccessorMsgImpl is reused for all three types of $Msg$, $Msg$Mut and +// $Msg$View; this enum signifies which case we are handling so corresponding +// adjustments can be made (for example: to not emit any mutation accessors +// on $Msg$View). +enum class AccessorCase { + OWNED, + MUT, + VIEW, +}; + +#endif // GOOGLE_PROTOBUF_COMPILER_RUST_ACCESSORS_ACCESSOR_CASE_H__ diff --git a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h index c440206f36..b9d8684270 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h +++ b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h @@ -13,6 +13,7 @@ #include #include "absl/log/absl_check.h" +#include "google/protobuf/compiler/rust/accessors/accessor_case.h" #include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/compiler/rust/naming.h" #include "google/protobuf/descriptor.h" @@ -38,18 +39,22 @@ class AccessorGenerator { static std::unique_ptr For(Context& ctx, const FieldDescriptor& field); - void GenerateMsgImpl(Context& ctx, const FieldDescriptor& field) const { + void GenerateMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const { ctx.Emit({{"comment", FieldInfoComment(ctx, field)}}, R"rs( // $comment$ )rs"); - InMsgImpl(ctx, field); + InMsgImpl(ctx, field, accessor_case); + ctx.printer().PrintRaw("\n"); } void GenerateExternC(Context& ctx, const FieldDescriptor& field) const { InExternC(ctx, field); + ctx.printer().PrintRaw("\n"); } void GenerateThunkCc(Context& ctx, const FieldDescriptor& field) const { ABSL_CHECK(ctx.is_cpp()); InThunkCc(ctx, field); + ctx.printer().PrintRaw("\n"); } private: @@ -58,8 +63,10 @@ class AccessorGenerator { // functions. For example, consider calling `field.printer.WithVars()` as a // prologue to inject variables automatically. - // Called inside the main inherent `impl Msg {}` block. - virtual void InMsgImpl(Context& ctx, const FieldDescriptor& field) const {} + // Called inside the `impl Msg {}`, `impl MsgMut {}` and `impl MsgView` + // blocks. + virtual void InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const {} // Called inside of a message's `extern "C" {}` block. virtual void InExternC(Context& ctx, const FieldDescriptor& field) const {} @@ -72,7 +79,8 @@ class AccessorGenerator { class SingularScalar final : public AccessorGenerator { public: ~SingularScalar() override = default; - void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const override; void InExternC(Context& ctx, const FieldDescriptor& field) const override; void InThunkCc(Context& ctx, const FieldDescriptor& field) const override; }; @@ -80,7 +88,8 @@ class SingularScalar final : public AccessorGenerator { class SingularString final : public AccessorGenerator { public: ~SingularString() override = default; - void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const override; void InExternC(Context& ctx, const FieldDescriptor& field) const override; void InThunkCc(Context& ctx, const FieldDescriptor& field) const override; }; @@ -88,7 +97,8 @@ class SingularString final : public AccessorGenerator { class SingularMessage final : public AccessorGenerator { public: ~SingularMessage() override = default; - void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const override; void InExternC(Context& ctx, const FieldDescriptor& field) const override; void InThunkCc(Context& ctx, const FieldDescriptor& field) const override; }; @@ -96,7 +106,8 @@ class SingularMessage final : public AccessorGenerator { class RepeatedScalar final : public AccessorGenerator { public: ~RepeatedScalar() override = default; - void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const override; void InExternC(Context& ctx, const FieldDescriptor& field) const override; void InThunkCc(Context& ctx, const FieldDescriptor& field) const override; }; @@ -105,7 +116,8 @@ class UnsupportedField final : public AccessorGenerator { public: explicit UnsupportedField(std::string reason) : reason_(std::move(reason)) {} ~UnsupportedField() override = default; - void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const override; private: std::string reason_; @@ -114,7 +126,8 @@ class UnsupportedField final : public AccessorGenerator { class Map final : public AccessorGenerator { public: ~Map() override = default; - void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const override; void InExternC(Context& ctx, const FieldDescriptor& field) const override; void InThunkCc(Context& ctx, const FieldDescriptor& field) const override; }; diff --git a/src/google/protobuf/compiler/rust/accessors/accessors.cc b/src/google/protobuf/compiler/rust/accessors/accessors.cc index 2481c29918..40af904ff8 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessors.cc +++ b/src/google/protobuf/compiler/rust/accessors/accessors.cc @@ -101,8 +101,9 @@ std::unique_ptr AccessorGeneratorFor( } // namespace -void GenerateAccessorMsgImpl(Context& ctx, const FieldDescriptor& field) { - AccessorGeneratorFor(ctx, field)->GenerateMsgImpl(ctx, field); +void GenerateAccessorMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) { + AccessorGeneratorFor(ctx, field)->GenerateMsgImpl(ctx, field, accessor_case); } void GenerateAccessorExternC(Context& ctx, const FieldDescriptor& field) { diff --git a/src/google/protobuf/compiler/rust/accessors/accessors.h b/src/google/protobuf/compiler/rust/accessors/accessors.h index 801bff2057..222c096c6b 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessors.h +++ b/src/google/protobuf/compiler/rust/accessors/accessors.h @@ -8,6 +8,7 @@ #ifndef GOOGLE_PROTOBUF_COMPILER_RUST_ACCESSORS_ACCESSORS_H__ #define GOOGLE_PROTOBUF_COMPILER_RUST_ACCESSORS_ACCESSORS_H__ +#include "google/protobuf/compiler/rust/accessors/accessor_case.h" #include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/descriptor.h" @@ -16,7 +17,11 @@ namespace protobuf { namespace compiler { namespace rust { -void GenerateAccessorMsgImpl(Context& ctx, const FieldDescriptor& field); +// Generates the Rust accessors: expected to be called once each for each +// Message, MessageMut and MessageView's impl. +void GenerateAccessorMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case); + void GenerateAccessorExternC(Context& ctx, const FieldDescriptor& field); void GenerateAccessorThunkCc(Context& ctx, const FieldDescriptor& field); diff --git a/src/google/protobuf/compiler/rust/accessors/map.cc b/src/google/protobuf/compiler/rust/accessors/map.cc index 0a13e1694e..0850b65eac 100644 --- a/src/google/protobuf/compiler/rust/accessors/map.cc +++ b/src/google/protobuf/compiler/rust/accessors/map.cc @@ -6,6 +6,7 @@ // https://developers.google.com/open-source/licenses/bsd #include "google/protobuf/compiler/cpp/helpers.h" +#include "google/protobuf/compiler/rust/accessors/accessor_case.h" #include "google/protobuf/compiler/rust/accessors/accessor_generator.h" #include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/compiler/rust/naming.h" @@ -17,7 +18,8 @@ namespace protobuf { namespace compiler { namespace rust { -void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field) const { +void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const { auto& key_type = *field.message_type()->map_key(); auto& value_type = *field.message_type()->map_value(); @@ -53,6 +55,9 @@ void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field) const { }}, {"getter_mut", [&] { + if (accessor_case == AccessorCase::VIEW) { + return; + } if (ctx.is_upb()) { ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) diff --git a/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc b/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc index a1d188d050..c1d0952202 100644 --- a/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc +++ b/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc @@ -7,6 +7,7 @@ #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" +#include "google/protobuf/compiler/rust/accessors/accessor_case.h" #include "google/protobuf/compiler/rust/accessors/accessor_generator.h" #include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/compiler/rust/naming.h" @@ -17,8 +18,8 @@ namespace protobuf { namespace compiler { namespace rust { -void RepeatedScalar::InMsgImpl(Context& ctx, - const FieldDescriptor& field) const { +void RepeatedScalar::InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const { ctx.Emit({{"field", field.name()}, {"Scalar", RsTypePath(ctx, field)}, {"getter_thunk", ThunkName(ctx, field, "get")}, @@ -55,8 +56,11 @@ void RepeatedScalar::InMsgImpl(Context& ctx, } }}, {"clearer_thunk", ThunkName(ctx, field, "clear")}, - {"field_mutator_getter", + {"getter_mut", [&] { + if (accessor_case == AccessorCase::VIEW) { + return; + } if (ctx.is_upb()) { ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> { @@ -70,7 +74,7 @@ void RepeatedScalar::InMsgImpl(Context& ctx, /* optional size pointer */ std::ptr::null(), self.arena().raw(), ), - &self.inner.arena, + self.arena(), ), ) } @@ -94,7 +98,7 @@ void RepeatedScalar::InMsgImpl(Context& ctx, }}}, R"rs( $getter$ - $field_mutator_getter$ + $getter_mut$ )rs"); } diff --git a/src/google/protobuf/compiler/rust/accessors/singular_message.cc b/src/google/protobuf/compiler/rust/accessors/singular_message.cc index acfaad7d7e..9862a26786 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_message.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_message.cc @@ -9,6 +9,7 @@ #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" +#include "google/protobuf/compiler/rust/accessors/accessor_case.h" #include "google/protobuf/compiler/rust/accessors/accessor_generator.h" #include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/compiler/rust/naming.h" @@ -19,22 +20,20 @@ namespace protobuf { namespace compiler { namespace rust { -void SingularMessage::InMsgImpl(Context& ctx, - const FieldDescriptor& field) const { +void SingularMessage::InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const { // fully qualified message name with modules prefixed std::string msg_type = RsTypePath(ctx, field); - ctx.Emit( - { - {"msg_type", msg_type}, - {"field", field.name()}, - {"getter_thunk", ThunkName(ctx, field, "get")}, - {"getter_mut_thunk", ThunkName(ctx, field, "get_mut")}, - {"clearer_thunk", ThunkName(ctx, field, "clear")}, - { - "view_body", - [&] { - if (ctx.is_upb()) { - ctx.Emit({}, R"rs( + ctx.Emit({{"msg_type", msg_type}, + {"field", field.name()}, + {"getter_thunk", ThunkName(ctx, field, "get")}, + {"getter_mut_thunk", ThunkName(ctx, field, "get_mut")}, + {"clearer_thunk", ThunkName(ctx, field, "clear")}, + { + "getter_body", + [&] { + if (ctx.is_upb()) { + ctx.Emit({}, R"rs( let submsg = unsafe { $getter_thunk$(self.raw_msg()) }; //~ For upb, getters return null if the field is unset, so we need //~ to check for null and return the default instance manually. @@ -46,45 +45,64 @@ void SingularMessage::InMsgImpl(Context& ctx, Some(field) => $msg_type$View::new($pbi$::Private, field), } )rs"); - } else { - ctx.Emit({}, R"rs( + } else { + ctx.Emit({}, R"rs( //~ For C++ kernel, getters automatically return the //~ default_instance if the field is unset. let submsg = unsafe { $getter_thunk$(self.raw_msg()) }; $msg_type$View::new($pbi$::Private, submsg) )rs"); - } - }, - }, - {"submessage_mut", - [&] { - if (ctx.is_upb()) { + } + }, + }, + {"getter", + [&] { ctx.Emit({}, R"rs( + pub fn r#$field$(&self) -> $msg_type$View { + $getter_body$ + } + )rs"); + }}, + {"getter_mut_body", + [&] { + if (ctx.is_upb()) { + ctx.Emit({}, R"rs( let submsg = unsafe { $getter_mut_thunk$(self.raw_msg(), self.arena().raw()) }; - $msg_type$Mut::from_parent($pbi$::Private, &mut self.inner, submsg) + $msg_type$Mut::from_parent($pbi$::Private, self.as_mutator_message_ref(), submsg) )rs"); - } else { - ctx.Emit({}, R"rs( + } else { + ctx.Emit({}, R"rs( let submsg = unsafe { $getter_mut_thunk$(self.raw_msg()) }; - $msg_type$Mut::from_parent($pbi$::Private, &mut self.inner, submsg) + $msg_type$Mut::from_parent($pbi$::Private, self.as_mutator_message_ref(), submsg) )rs"); - } - }}, - }, - R"rs( - pub fn r#$field$(&self) -> $msg_type$View { - $view_body$ - } - - pub fn $field$_mut(&mut self) -> $msg_type$Mut { - $submessage_mut$ - } - - pub fn $field$_clear(&mut self) { - unsafe { $clearer_thunk$(self.raw_msg()) } - } + } + }}, + {"getter_mut", + [&] { + if (accessor_case == AccessorCase::VIEW) { + return; + } + ctx.Emit({}, R"rs( + pub fn $field$_mut(&mut self) -> $msg_type$Mut { + $getter_mut_body$ + })rs"); + }}, + {"clearer", + [&] { + if (accessor_case == AccessorCase::VIEW) { + return; + } + ctx.Emit({}, R"rs( + pub fn $field$_clear(&mut self) { + unsafe { $clearer_thunk$(self.raw_msg()) } + })rs"); + }}}, + R"rs( + $getter$ + $getter_mut$ + $clearer$ )rs"); } diff --git a/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc b/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc index 129142ebb9..8458deb35d 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc @@ -9,6 +9,7 @@ #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" +#include "google/protobuf/compiler/rust/accessors/accessor_case.h" #include "google/protobuf/compiler/rust/accessors/accessor_generator.h" #include "google/protobuf/compiler/rust/accessors/helpers.h" #include "google/protobuf/compiler/rust/context.h" @@ -20,8 +21,8 @@ namespace protobuf { namespace compiler { namespace rust { -void SingularScalar::InMsgImpl(Context& ctx, - const FieldDescriptor& field) const { +void SingularScalar::InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const { ctx.Emit( { {"field", field.name()}, @@ -53,8 +54,11 @@ void SingularScalar::InMsgImpl(Context& ctx, {"getter_thunk", ThunkName(ctx, field, "get")}, {"setter_thunk", ThunkName(ctx, field, "set")}, {"clearer_thunk", ThunkName(ctx, field, "clear")}, - {"field_mutator_getter", + {"getter_mut", [&] { + if (accessor_case == AccessorCase::VIEW) { + return; + } if (field.has_presence()) { ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) -> $pb$::FieldEntry<'_, $Scalar$> { @@ -71,7 +75,7 @@ void SingularScalar::InMsgImpl(Context& ctx, let has = $hazzer_thunk$(self.raw_msg()); $pbi$::new_vtable_field_entry::<$Scalar$>( $pbi$::Private, - $pbr$::MutatorMessageRef::new($pbi$::Private, &mut self.inner), + self.as_mutator_message_ref(), &VTABLE, has, ) @@ -98,9 +102,7 @@ void SingularScalar::InMsgImpl(Context& ctx, $pbi$::Private, $pbi$::RawVTableMutator::new( $pbi$::Private, - $pbr$::MutatorMessageRef::new( - $pbi$::Private, &mut self.inner - ), + self.as_mutator_message_ref(), &VTABLE, ), ) @@ -113,7 +115,7 @@ void SingularScalar::InMsgImpl(Context& ctx, R"rs( $getter$ $getter_opt$ - $field_mutator_getter$ + $getter_mut$ )rs"); } diff --git a/src/google/protobuf/compiler/rust/accessors/singular_string.cc b/src/google/protobuf/compiler/rust/accessors/singular_string.cc index e4357ecc8e..1f86d33400 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_string.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_string.cc @@ -9,6 +9,7 @@ #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" +#include "google/protobuf/compiler/rust/accessors/accessor_case.h" #include "google/protobuf/compiler/rust/accessors/accessor_generator.h" #include "google/protobuf/compiler/rust/accessors/helpers.h" #include "google/protobuf/compiler/rust/context.h" @@ -20,8 +21,8 @@ namespace protobuf { namespace compiler { namespace rust { -void SingularString::InMsgImpl(Context& ctx, - const FieldDescriptor& field) const { +void SingularString::InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const { std::string hazzer_thunk = ThunkName(ctx, field, "has"); std::string getter_thunk = ThunkName(ctx, field, "get"); std::string setter_thunk = ThunkName(ctx, field, "set"); @@ -63,6 +64,9 @@ void SingularString::InMsgImpl(Context& ctx, }}, {"field_mutator_getter", [&] { + if (accessor_case == AccessorCase::VIEW) { + return; + } if (field.has_presence()) { ctx.Emit( { @@ -102,8 +106,7 @@ void SingularString::InMsgImpl(Context& ctx, let has = $hazzer_thunk$(self.raw_msg()); $pbi$::new_vtable_field_entry( $pbi$::Private, - $pbr$::MutatorMessageRef::new( - $pbi$::Private, &mut self.inner), + self.as_mutator_message_ref(), &VTABLE, has, ) @@ -129,8 +132,7 @@ void SingularString::InMsgImpl(Context& ctx, $pbi$::Private, $pbi$::RawVTableMutator::new( $pbi$::Private, - $pbr$::MutatorMessageRef::new( - $pbi$::Private, &mut self.inner), + self.as_mutator_message_ref(), &VTABLE, ) ) diff --git a/src/google/protobuf/compiler/rust/accessors/unsupported_field.cc b/src/google/protobuf/compiler/rust/accessors/unsupported_field.cc index 591298a2bc..a04e778cd0 100644 --- a/src/google/protobuf/compiler/rust/accessors/unsupported_field.cc +++ b/src/google/protobuf/compiler/rust/accessors/unsupported_field.cc @@ -6,6 +6,7 @@ // https://developers.google.com/open-source/licenses/bsd #include "absl/strings/string_view.h" +#include "google/protobuf/compiler/rust/accessors/accessor_case.h" #include "google/protobuf/compiler/rust/accessors/accessor_generator.h" #include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/descriptor.h" @@ -15,8 +16,8 @@ namespace protobuf { namespace compiler { namespace rust { -void UnsupportedField::InMsgImpl(Context& ctx, - const FieldDescriptor& field) const { +void UnsupportedField::InMsgImpl(Context& ctx, const FieldDescriptor& field, + AccessorCase accessor_case) const { ctx.Emit({{"reason", reason_}}, R"rs( // Unsupported! :( Reason: $reason$ )rs"); diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 654da7c90c..d3098748c9 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -11,10 +11,10 @@ #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" #include "google/protobuf/compiler/cpp/names.h" +#include "google/protobuf/compiler/rust/accessors/accessor_case.h" #include "google/protobuf/compiler/rust/accessors/accessors.h" #include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/compiler/rust/enum.h" @@ -178,10 +178,6 @@ void MessageDrop(Context& ctx, const Descriptor& msg) { )rs"); } -bool IsStringOrBytes(FieldDescriptor::Type t) { - return t == FieldDescriptor::TYPE_STRING || t == FieldDescriptor::TYPE_BYTES; -} - void MessageSettableValue(Context& ctx, const Descriptor& msg) { switch (ctx.opts().kernel) { case Kernel::kCpp: @@ -217,131 +213,6 @@ void MessageSettableValue(Context& ctx, const Descriptor& msg) { ABSL_LOG(FATAL) << "unreachable"; } -void GetterForViewOrMut(Context& ctx, const FieldDescriptor& field, - bool is_mut) { - auto fieldName = field.name(); - auto fieldType = field.type(); - auto getter_thunk = ThunkName(ctx, field, "get"); - auto setter_thunk = ThunkName(ctx, field, "set"); - - if (fieldType == FieldDescriptor::TYPE_MESSAGE) { - const Descriptor& msg = *field.message_type(); - // TODO: support messages which are defined in other crates. - if (!IsInCurrentlyGeneratingCrate(ctx, msg)) { - return; - } - auto prefix = RsTypePath(ctx, field); - ctx.Emit( - { - {"prefix", prefix}, - {"field", fieldName}, - {"getter_thunk", getter_thunk}, - // TODO: dedupe with singular_message.cc - { - "view_body", - [&] { - if (ctx.is_upb()) { - ctx.Emit({}, R"rs( - let submsg = unsafe { $getter_thunk$(self.raw_msg()) }; - match submsg { - None => $prefix$View::new($pbi$::Private, - $pbr$::ScratchSpace::zeroed_block($pbi$::Private)), - Some(field) => $prefix$View::new($pbi$::Private, field), - } - )rs"); - } else { - ctx.Emit({}, R"rs( - let submsg = unsafe { $getter_thunk$(self.raw_msg()) }; - $prefix$View::new($pbi$::Private, submsg) - )rs"); - } - }, - }, - }, - R"rs( - pub fn r#$field$(&self) -> $prefix$View { - $view_body$ - } - )rs"); - return; - } - - auto rsType = RsTypePath(ctx, field); - auto asRef = IsStringOrBytes(fieldType) ? ".as_ref()" : ""; - auto vtable = - IsStringOrBytes(fieldType) ? "BytesMutVTable" : "PrimitiveVTable"; - // PrimitiveVtable is parameterized based on the underlying primitive, like - // u32 so we need to provide this additional type arg - auto optionalTypeArgs = - IsStringOrBytes(fieldType) ? "" : absl::StrFormat("<%s>", rsType); - // need to stuff ProtoStr and [u8] behind a reference since they are DSTs - auto stringTransform = - IsStringOrBytes(fieldType) - ? "unsafe { __pb::ProtoStr::from_utf8_unchecked(res).into() }" - : "res"; - - // TODO: support enums which are defined in other crates. - auto enum_ = field.enum_type(); - if (enum_ != nullptr && !IsInCurrentlyGeneratingCrate(ctx, *enum_)) { - return; - } - - ctx.Emit({{"field", fieldName}, - {"getter_thunk", getter_thunk}, - {"setter_thunk", setter_thunk}, - {"RsType", rsType}, - {"as_ref", asRef}, - {"vtable", vtable}, - {"optional_type_args", optionalTypeArgs}, - {"string_transform", stringTransform}, - {"maybe_mutator", - [&] { - // TODO: check mutational pathway genn'd correctly - if (is_mut) { - ctx.Emit({}, R"rs( - pub fn r#$field$_mut(&mut self) -> $pb$::Mut<'_, $RsType$> { - static VTABLE: $pbi$::$vtable$$optional_type_args$ = - $pbi$::$vtable$::new( - $pbi$::Private, - $getter_thunk$, - $setter_thunk$); - unsafe { - <$pb$::Mut<$RsType$>>::from_inner( - $pbi$::Private, - $pbi$::RawVTableMutator::new( - $pbi$::Private, - self.inner, - &VTABLE - ), - ) - } - } - )rs"); - } - }}}, - R"rs( - pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> { - let res = unsafe { $getter_thunk$(self.raw_msg())$as_ref$ }; - $string_transform$ - } - - $maybe_mutator$ - )rs"); -} - -void AccessorsForViewOrMut(Context& ctx, const Descriptor& msg, bool is_mut) { - for (int i = 0; i < msg.field_count(); ++i) { - const FieldDescriptor& field = *msg.field(i); - if (field.is_repeated()) continue; - // TODO - add cord support - if (field.options().has_ctype()) continue; - // TODO - if (field.type() == FieldDescriptor::TYPE_GROUP) continue; - GetterForViewOrMut(ctx, field, is_mut); - ctx.printer().PrintRaw("\n"); - } -} - } // namespace void GenerateRs(Context& ctx, const Descriptor& msg) { @@ -358,29 +229,26 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { {"accessor_fns", [&] { for (int i = 0; i < msg.field_count(); ++i) { - GenerateAccessorMsgImpl(ctx, *msg.field(i)); - ctx.printer().PrintRaw("\n"); + GenerateAccessorMsgImpl(ctx, *msg.field(i), + AccessorCase::OWNED); } }}, {"oneof_accessor_fns", [&] { for (int i = 0; i < msg.real_oneof_decl_count(); ++i) { GenerateOneofAccessors(ctx, *msg.real_oneof_decl(i)); - ctx.printer().PrintRaw("\n"); } }}, {"accessor_externs", [&] { for (int i = 0; i < msg.field_count(); ++i) { GenerateAccessorExternC(ctx, *msg.field(i)); - ctx.printer().PrintRaw("\n"); } }}, {"oneof_externs", [&] { for (int i = 0; i < msg.real_oneof_decl_count(); ++i) { GenerateOneofExternC(ctx, *msg.real_oneof_decl(i)); - ctx.printer().PrintRaw("\n"); } }}, {"nested_in_msg", @@ -442,9 +310,18 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { } }}, {"accessor_fns_for_views", - [&] { AccessorsForViewOrMut(ctx, msg, false); }}, + [&] { + for (int i = 0; i < msg.field_count(); ++i) { + GenerateAccessorMsgImpl(ctx, *msg.field(i), + AccessorCase::VIEW); + } + }}, {"accessor_fns_for_muts", - [&] { AccessorsForViewOrMut(ctx, msg, true); }}, + [&] { + for (int i = 0; i < msg.field_count(); ++i) { + GenerateAccessorMsgImpl(ctx, *msg.field(i), AccessorCase::MUT); + } + }}, {"settable_impl", [&] { MessageSettableValue(ctx, msg); }}}, R"rs( #[allow(non_camel_case_types)] @@ -519,8 +396,9 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { #[allow(dead_code)] impl<'a> $Msg$Mut<'a> { #[doc(hidden)] - pub fn from_parent(_private: $pbi$::Private, - parent: &'a mut $pbr$::MessageInner, + pub fn from_parent( + _private: $pbi$::Private, + parent: $pbr$::MutatorMessageRef<'a>, msg: $pbi$::RawMessage) -> Self { Self { @@ -528,6 +406,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { $pbi$::Private, parent, msg) } } + #[doc(hidden)] pub fn new(_private: $pbi$::Private, msg: &'a mut $pbr$::MessageInner) -> Self { Self{ inner: $pbr$::MutatorMessageRef::new(_private, msg) } @@ -537,6 +416,10 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { self.inner.msg() } + fn as_mutator_message_ref(&mut self) -> $pbr$::MutatorMessageRef<'a> { + self.inner + } + $raw_arena_getter_for_msgmut$ $accessor_fns_for_muts$ @@ -575,6 +458,10 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { self.inner.msg } + fn as_mutator_message_ref(&mut self) -> $pbr$::MutatorMessageRef { + $pbr$::MutatorMessageRef::new($pbi$::Private, &mut self.inner) + } + $raw_arena_getter_for_message$ pub fn serialize(&self) -> $pbr$::SerializedData {