From 542ca772fad96ca085adbca2d1d45d7a11a839cc Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Mon, 18 Dec 2023 15:39:46 -0800 Subject: [PATCH] Decouple Context from the Descriptor PiperOrigin-RevId: 592029759 --- .../rust/accessors/accessor_generator.h | 63 ++-- .../compiler/rust/accessors/accessors.cc | 35 +- .../compiler/rust/accessors/accessors.h | 6 +- .../compiler/rust/accessors/helpers.cc | 42 ++- .../compiler/rust/accessors/helpers.h | 3 +- .../protobuf/compiler/rust/accessors/map.cc | 87 +++-- .../rust/accessors/repeated_scalar.cc | 124 +++---- .../rust/accessors/singular_message.cc | 85 ++--- .../rust/accessors/singular_scalar.cc | 112 +++--- .../rust/accessors/singular_string.cc | 139 ++++---- .../rust/accessors/unsupported_field.cc | 7 +- src/google/protobuf/compiler/rust/context.cc | 9 +- src/google/protobuf/compiler/rust/context.h | 42 +-- .../protobuf/compiler/rust/generator.cc | 165 +++++---- src/google/protobuf/compiler/rust/message.cc | 337 +++++++++--------- src/google/protobuf/compiler/rust/message.h | 4 +- src/google/protobuf/compiler/rust/naming.cc | 120 +++---- src/google/protobuf/compiler/rust/naming.h | 26 +- src/google/protobuf/compiler/rust/oneof.cc | 156 ++++---- src/google/protobuf/compiler/rust/oneof.h | 8 +- 20 files changed, 768 insertions(+), 802 deletions(-) diff --git a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h index 442de3c7cc..6aad7bba23 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h +++ b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h @@ -26,25 +26,26 @@ class AccessorGenerator { AccessorGenerator() = default; virtual ~AccessorGenerator() = default; - AccessorGenerator(const AccessorGenerator &) = delete; - AccessorGenerator(AccessorGenerator &&) = delete; - AccessorGenerator &operator=(const AccessorGenerator &) = delete; - AccessorGenerator &operator=(AccessorGenerator &&) = delete; + AccessorGenerator(const AccessorGenerator&) = delete; + AccessorGenerator(AccessorGenerator&&) = delete; + AccessorGenerator& operator=(const AccessorGenerator&) = delete; + AccessorGenerator& operator=(AccessorGenerator&&) = delete; // Constructs a generator for the given field. // // Returns `nullptr` if there is no known generator for this field. - static std::unique_ptr For(Context field); + static std::unique_ptr For(Context& ctx, + const FieldDescriptor& field); - void GenerateMsgImpl(Context field) const { - InMsgImpl(field); + void GenerateMsgImpl(Context& ctx, const FieldDescriptor& field) const { + InMsgImpl(ctx, field); } - void GenerateExternC(Context field) const { - InExternC(field); + void GenerateExternC(Context& ctx, const FieldDescriptor& field) const { + InExternC(ctx, field); } - void GenerateThunkCc(Context field) const { - ABSL_CHECK(field.is_cpp()); - InThunkCc(field); + void GenerateThunkCc(Context& ctx, const FieldDescriptor& field) const { + ABSL_CHECK(ctx.is_cpp()); + InThunkCc(ctx, field); } private: @@ -54,53 +55,53 @@ class AccessorGenerator { // prologue to inject variables automatically. // Called inside the main inherent `impl Msg {}` block. - virtual void InMsgImpl(Context field) const {} + virtual void InMsgImpl(Context& ctx, const FieldDescriptor& field) const {} // Called inside of a message's `extern "C" {}` block. - virtual void InExternC(Context field) const {} + virtual void InExternC(Context& ctx, const FieldDescriptor& field) const {} // Called inside of an `extern "C" {}` block in the `.thunk.cc` file, if such // a file is being generated. - virtual void InThunkCc(Context field) const {} + virtual void InThunkCc(Context& ctx, const FieldDescriptor& field) const {} }; class SingularScalar final : public AccessorGenerator { public: ~SingularScalar() override = default; - void InMsgImpl(Context field) const override; - void InExternC(Context field) const override; - void InThunkCc(Context field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InExternC(Context& ctx, const FieldDescriptor& field) const override; + void InThunkCc(Context& ctx, const FieldDescriptor& field) const override; }; class SingularString final : public AccessorGenerator { public: ~SingularString() override = default; - void InMsgImpl(Context field) const override; - void InExternC(Context field) const override; - void InThunkCc(Context field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InExternC(Context& ctx, const FieldDescriptor& field) const override; + void InThunkCc(Context& ctx, const FieldDescriptor& field) const override; }; class SingularMessage final : public AccessorGenerator { public: ~SingularMessage() override = default; - void InMsgImpl(Context field) const override; - void InExternC(Context field) const override; - void InThunkCc(Context field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InExternC(Context& ctx, const FieldDescriptor& field) const override; + void InThunkCc(Context& ctx, const FieldDescriptor& field) const override; }; class RepeatedScalar final : public AccessorGenerator { public: ~RepeatedScalar() override = default; - void InMsgImpl(Context field) const override; - void InExternC(Context field) const override; - void InThunkCc(Context field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InExternC(Context& ctx, const FieldDescriptor& field) const override; + void InThunkCc(Context& ctx, const FieldDescriptor& field) const override; }; class UnsupportedField final : public AccessorGenerator { public: explicit UnsupportedField(std::string reason) : reason_(std::move(reason)) {} ~UnsupportedField() override = default; - void InMsgImpl(Context field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; private: std::string reason_; @@ -109,9 +110,9 @@ class UnsupportedField final : public AccessorGenerator { class Map final : public AccessorGenerator { public: ~Map() override = default; - void InMsgImpl(Context field) const override; - void InExternC(Context field) const override; - void InThunkCc(Context field) const override; + void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override; + void InExternC(Context& ctx, const FieldDescriptor& field) const override; + void InThunkCc(Context& ctx, const FieldDescriptor& field) const override; }; } // namespace rust diff --git a/src/google/protobuf/compiler/rust/accessors/accessors.cc b/src/google/protobuf/compiler/rust/accessors/accessors.cc index c25fbc207f..d4c93f00ba 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessors.cc +++ b/src/google/protobuf/compiler/rust/accessors/accessors.cc @@ -23,17 +23,16 @@ namespace rust { namespace { std::unique_ptr AccessorGeneratorFor( - Context field) { - const FieldDescriptor& desc = field.desc(); + Context& ctx, const FieldDescriptor& field) { // TODO: We do not support [ctype=FOO] (used to set the field // type in C++ to cord or string_piece) in V0.6 API. - if (desc.options().has_ctype()) { + if (field.options().has_ctype()) { return std::make_unique( "fields with ctype not supported"); } - if (desc.is_map()) { - auto value_type = desc.message_type()->map_value()->type(); + if (field.is_map()) { + auto value_type = field.message_type()->map_value()->type(); switch (value_type) { case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_ENUM: @@ -46,7 +45,7 @@ std::unique_ptr AccessorGeneratorFor( } } - switch (desc.type()) { + switch (field.type()) { case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_INT64: case FieldDescriptor::TYPE_FIXED32: @@ -60,22 +59,22 @@ std::unique_ptr AccessorGeneratorFor( case FieldDescriptor::TYPE_FLOAT: case FieldDescriptor::TYPE_DOUBLE: case FieldDescriptor::TYPE_BOOL: - if (desc.is_repeated()) { + if (field.is_repeated()) { return std::make_unique(); } return std::make_unique(); case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_STRING: - if (desc.is_repeated()) { + if (field.is_repeated()) { return std::make_unique("repeated str not supported"); } return std::make_unique(); case FieldDescriptor::TYPE_MESSAGE: - if (desc.is_repeated()) { + if (field.is_repeated()) { return std::make_unique("repeated msg not supported"); } - if (!field.generator_context().is_file_in_current_crate( - desc.message_type()->file())) { + if (!ctx.generator_context().is_file_in_current_crate( + *field.message_type()->file())) { return std::make_unique( "message fields that are imported from another proto_library" " (defined in a separate Rust crate) are not supported"); @@ -89,21 +88,21 @@ std::unique_ptr AccessorGeneratorFor( return std::make_unique("group not supported"); } - ABSL_LOG(FATAL) << "Unexpected field type: " << desc.type(); + ABSL_LOG(FATAL) << "Unexpected field type: " << field.type(); } } // namespace -void GenerateAccessorMsgImpl(Context field) { - AccessorGeneratorFor(field)->GenerateMsgImpl(field); +void GenerateAccessorMsgImpl(Context& ctx, const FieldDescriptor& field) { + AccessorGeneratorFor(ctx, field)->GenerateMsgImpl(ctx, field); } -void GenerateAccessorExternC(Context field) { - AccessorGeneratorFor(field)->GenerateExternC(field); +void GenerateAccessorExternC(Context& ctx, const FieldDescriptor& field) { + AccessorGeneratorFor(ctx, field)->GenerateExternC(ctx, field); } -void GenerateAccessorThunkCc(Context field) { - AccessorGeneratorFor(field)->GenerateThunkCc(field); +void GenerateAccessorThunkCc(Context& ctx, const FieldDescriptor& field) { + AccessorGeneratorFor(ctx, field)->GenerateThunkCc(ctx, field); } } // namespace rust diff --git a/src/google/protobuf/compiler/rust/accessors/accessors.h b/src/google/protobuf/compiler/rust/accessors/accessors.h index 05687e5477..801bff2057 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessors.h +++ b/src/google/protobuf/compiler/rust/accessors/accessors.h @@ -16,9 +16,9 @@ namespace protobuf { namespace compiler { namespace rust { -void GenerateAccessorMsgImpl(Context field); -void GenerateAccessorExternC(Context field); -void GenerateAccessorThunkCc(Context field); +void GenerateAccessorMsgImpl(Context& ctx, const FieldDescriptor& field); +void GenerateAccessorExternC(Context& ctx, const FieldDescriptor& field); +void GenerateAccessorThunkCc(Context& ctx, const FieldDescriptor& field); } // namespace rust } // namespace compiler diff --git a/src/google/protobuf/compiler/rust/accessors/helpers.cc b/src/google/protobuf/compiler/rust/accessors/helpers.cc index b2d2a1bed8..ba41abd9b5 100644 --- a/src/google/protobuf/compiler/rust/accessors/helpers.cc +++ b/src/google/protobuf/compiler/rust/accessors/helpers.cc @@ -15,7 +15,6 @@ #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/strtod.h" @@ -24,33 +23,32 @@ namespace protobuf { namespace compiler { namespace rust { -std::string DefaultValue(Context field) { - switch (field.desc().type()) { +std::string DefaultValue(const FieldDescriptor& field) { + switch (field.type()) { case FieldDescriptor::TYPE_DOUBLE: - if (std::isfinite(field.desc().default_value_double())) { - return absl::StrCat(io::SimpleDtoa(field.desc().default_value_double()), + if (std::isfinite(field.default_value_double())) { + return absl::StrCat(io::SimpleDtoa(field.default_value_double()), "f64"); - } else if (std::isnan(field.desc().default_value_double())) { + } else if (std::isnan(field.default_value_double())) { return std::string("f64::NAN"); - } else if (field.desc().default_value_double() == + } else if (field.default_value_double() == std::numeric_limits::infinity()) { return std::string("f64::INFINITY"); - } else if (field.desc().default_value_double() == + } else if (field.default_value_double() == -std::numeric_limits::infinity()) { return std::string("f64::NEG_INFINITY"); } else { ABSL_LOG(FATAL) << "unreachable"; } case FieldDescriptor::TYPE_FLOAT: - if (std::isfinite(field.desc().default_value_float())) { - return absl::StrCat(io::SimpleFtoa(field.desc().default_value_float()), - "f32"); - } else if (std::isnan(field.desc().default_value_float())) { + if (std::isfinite(field.default_value_float())) { + return absl::StrCat(io::SimpleFtoa(field.default_value_float()), "f32"); + } else if (std::isnan(field.default_value_float())) { return std::string("f32::NAN"); - } else if (field.desc().default_value_float() == + } else if (field.default_value_float() == std::numeric_limits::infinity()) { return std::string("f32::INFINITY"); - } else if (field.desc().default_value_float() == + } else if (field.default_value_float() == -std::numeric_limits::infinity()) { return std::string("f32::NEG_INFINITY"); } else { @@ -59,27 +57,27 @@ std::string DefaultValue(Context field) { case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_SFIXED32: case FieldDescriptor::TYPE_SINT32: - return absl::StrFormat("%d", field.desc().default_value_int32()); + return absl::StrFormat("%d", field.default_value_int32()); case FieldDescriptor::TYPE_INT64: case FieldDescriptor::TYPE_SFIXED64: case FieldDescriptor::TYPE_SINT64: - return absl::StrFormat("%d", field.desc().default_value_int64()); + return absl::StrFormat("%d", field.default_value_int64()); case FieldDescriptor::TYPE_FIXED64: case FieldDescriptor::TYPE_UINT64: - return absl::StrFormat("%u", field.desc().default_value_uint64()); + return absl::StrFormat("%u", field.default_value_uint64()); case FieldDescriptor::TYPE_FIXED32: case FieldDescriptor::TYPE_UINT32: - return absl::StrFormat("%u", field.desc().default_value_uint32()); + return absl::StrFormat("%u", field.default_value_uint32()); case FieldDescriptor::TYPE_BOOL: - return absl::StrFormat("%v", field.desc().default_value_bool()); + return absl::StrFormat("%v", field.default_value_bool()); case FieldDescriptor::TYPE_STRING: case FieldDescriptor::TYPE_BYTES: - return absl::StrFormat( - "b\"%s\"", absl::CHexEscape(field.desc().default_value_string())); + return absl::StrFormat("b\"%s\"", + absl::CHexEscape(field.default_value_string())); case FieldDescriptor::TYPE_GROUP: case FieldDescriptor::TYPE_MESSAGE: case FieldDescriptor::TYPE_ENUM: - ABSL_LOG(FATAL) << "Unsupported field type: " << field.desc().type_name(); + ABSL_LOG(FATAL) << "Unsupported field type: " << field.type_name(); } ABSL_LOG(FATAL) << "unreachable"; } diff --git a/src/google/protobuf/compiler/rust/accessors/helpers.h b/src/google/protobuf/compiler/rust/accessors/helpers.h index 45c7f3afa2..ee2c429e3b 100644 --- a/src/google/protobuf/compiler/rust/accessors/helpers.h +++ b/src/google/protobuf/compiler/rust/accessors/helpers.h @@ -10,7 +10,6 @@ #include -#include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/descriptor.h" namespace google { @@ -23,7 +22,7 @@ namespace rust { // Both strings and bytes are represented as a byte string literal, i.e. in the // format `b"default value here"`. It is the caller's responsibility to convert // the byte literal to an actual string, if needed. -std::string DefaultValue(Context field); +std::string DefaultValue(const FieldDescriptor& field); } // namespace rust } // namespace compiler diff --git a/src/google/protobuf/compiler/rust/accessors/map.cc b/src/google/protobuf/compiler/rust/accessors/map.cc index 0c2c1eb29d..89535d40e2 100644 --- a/src/google/protobuf/compiler/rust/accessors/map.cc +++ b/src/google/protobuf/compiler/rust/accessors/map.cc @@ -17,19 +17,19 @@ namespace protobuf { namespace compiler { namespace rust { -void Map::InMsgImpl(Context field) const { - auto& key_type = *field.desc().message_type()->map_key(); - auto& value_type = *field.desc().message_type()->map_value(); +void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field) const { + auto& key_type = *field.message_type()->map_key(); + auto& value_type = *field.message_type()->map_value(); - field.Emit({{"field", field.desc().name()}, - {"Key", PrimitiveRsTypeName(key_type)}, - {"Value", PrimitiveRsTypeName(value_type)}, - {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, - {"getter", - [&] { - if (field.is_upb()) { - field.Emit({}, R"rs( + ctx.Emit({{"field", field.name()}, + {"Key", PrimitiveRsTypeName(key_type)}, + {"Value", PrimitiveRsTypeName(value_type)}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"getter_mut_thunk", Thunk(ctx, field, "get_mut")}, + {"getter", + [&] { + if (ctx.is_upb()) { + ctx.Emit({}, R"rs( pub fn r#$field$(&self) -> $pb$::View<'_, $pb$::Map<$Key$, $Value$>> { let inner = unsafe { @@ -44,8 +44,8 @@ void Map::InMsgImpl(Context field) const { }); $pb$::MapView::from_inner($pbi$::Private, inner) })rs"); - } else { - field.Emit({}, R"rs( + } else { + ctx.Emit({}, R"rs( pub fn r#$field$(&self) -> $pb$::View<'_, $pb$::Map<$Key$, $Value$>> { let inner = $pbr$::MapInner { @@ -55,12 +55,12 @@ void Map::InMsgImpl(Context field) const { }; $pb$::MapView::from_inner($pbi$::Private, inner) })rs"); - } - }}, - {"getter_mut", - [&] { - if (field.is_upb()) { - field.Emit({}, R"rs( + } + }}, + {"getter_mut", + [&] { + if (ctx.is_upb()) { + ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) -> $pb$::Mut<'_, $pb$::Map<$Key$, $Value$>> { let raw = unsafe { @@ -75,8 +75,8 @@ void Map::InMsgImpl(Context field) const { }; $pb$::MapMut::from_inner($pbi$::Private, inner) })rs"); - } else { - field.Emit({}, R"rs( + } else { + ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) -> $pb$::Mut<'_, $pb$::Map<$Key$, $Value$>> { let inner = $pbr$::MapInner { @@ -86,30 +86,30 @@ void Map::InMsgImpl(Context field) const { }; $pb$::MapMut::from_inner($pbi$::Private, inner) })rs"); - } - }}}, - R"rs( + } + }}}, + R"rs( $getter$ $getter_mut$ )rs"); } -void Map::InExternC(Context field) const { - field.Emit( +void Map::InExternC(Context& ctx, const FieldDescriptor& field) const { + ctx.Emit( { - {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"getter_mut_thunk", Thunk(ctx, field, "get_mut")}, {"getter", [&] { - if (field.is_upb()) { - field.Emit({}, R"rs( + if (ctx.is_upb()) { + ctx.Emit({}, R"rs( fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> Option<$pbi$::RawMap>; fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage, arena: $pbi$::RawArena) -> $pbi$::RawMap; )rs"); } else { - field.Emit({}, R"rs( + ctx.Emit({}, R"rs( fn $getter_thunk$(msg: $pbi$::RawMessage) -> $pbi$::RawMap; fn $getter_mut_thunk$(msg: $pbi$::RawMessage,) -> $pbi$::RawMap; )rs"); @@ -121,20 +121,19 @@ void Map::InExternC(Context field) const { )rs"); } -void Map::InThunkCc(Context field) const { - field.Emit( - {{"field", cpp::FieldName(&field.desc())}, - {"Key", cpp::PrimitiveTypeName( - field.desc().message_type()->map_key()->cpp_type())}, - {"Value", cpp::PrimitiveTypeName( - field.desc().message_type()->map_value()->cpp_type())}, - {"QualifiedMsg", - cpp::QualifiedClassName(field.desc().containing_type())}, - {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, +void Map::InThunkCc(Context& ctx, const FieldDescriptor& field) const { + ctx.Emit( + {{"field", cpp::FieldName(&field)}, + {"Key", + cpp::PrimitiveTypeName(field.message_type()->map_key()->cpp_type())}, + {"Value", + cpp::PrimitiveTypeName(field.message_type()->map_value()->cpp_type())}, + {"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"getter_mut_thunk", Thunk(ctx, field, "get_mut")}, {"impls", [&] { - field.Emit( + ctx.Emit( R"cc( const void* $getter_thunk$($QualifiedMsg$& msg) { return &msg.$field$(); diff --git a/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc b/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc index 72f1d790b8..e0e1c6cffc 100644 --- a/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc +++ b/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc @@ -17,15 +17,16 @@ namespace protobuf { namespace compiler { namespace rust { -void RepeatedScalar::InMsgImpl(Context field) const { - field.Emit({{"field", field.desc().name()}, - {"Scalar", PrimitiveRsTypeName(field.desc())}, - {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, - {"getter", - [&] { - if (field.is_upb()) { - field.Emit({}, R"rs( +void RepeatedScalar::InMsgImpl(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit({{"field", field.name()}, + {"Scalar", PrimitiveRsTypeName(field)}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"getter_mut_thunk", Thunk(ctx, field, "get_mut")}, + {"getter", + [&] { + if (ctx.is_upb()) { + ctx.Emit({}, R"rs( pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> { unsafe { $getter_thunk$( @@ -40,8 +41,8 @@ void RepeatedScalar::InMsgImpl(Context field) const { ) } )rs"); - } else { - field.Emit({}, R"rs( + } else { + ctx.Emit({}, R"rs( pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> { unsafe { $pb$::RepeatedView::from_raw( @@ -51,13 +52,13 @@ void RepeatedScalar::InMsgImpl(Context field) const { } } )rs"); - } - }}, - {"clearer_thunk", Thunk(field, "clear")}, - {"field_mutator_getter", - [&] { - if (field.is_upb()) { - field.Emit({}, R"rs( + } + }}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, + {"field_mutator_getter", + [&] { + if (ctx.is_upb()) { + ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> { unsafe { $pb$::RepeatedMut::from_inner( @@ -75,8 +76,8 @@ void RepeatedScalar::InMsgImpl(Context field) const { } } )rs"); - } else { - field.Emit({}, R"rs( + } else { + ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> { unsafe { $pb$::RepeatedMut::from_inner( @@ -89,22 +90,23 @@ void RepeatedScalar::InMsgImpl(Context field) const { } } )rs"); - } - }}}, - R"rs( + } + }}}, + R"rs( $getter$ $field_mutator_getter$ )rs"); } -void RepeatedScalar::InExternC(Context field) const { - field.Emit({{"Scalar", PrimitiveRsTypeName(field.desc())}, - {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, - {"getter", - [&] { - if (field.is_upb()) { - field.Emit(R"rs( +void RepeatedScalar::InExternC(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit({{"Scalar", PrimitiveRsTypeName(field)}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"getter_mut_thunk", Thunk(ctx, field, "get_mut")}, + {"getter", + [&] { + if (ctx.is_upb()) { + ctx.Emit(R"rs( fn $getter_mut_thunk$( raw_msg: $pbi$::RawMessage, size: *const usize, @@ -116,44 +118,44 @@ void RepeatedScalar::InExternC(Context field) const { size: *const usize, ) -> Option<$pbi$::RawRepeatedField>; )rs"); - } else { - field.Emit(R"rs( + } else { + ctx.Emit(R"rs( fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField; fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField; )rs"); - } - }}, - {"clearer_thunk", Thunk(field, "clear")}}, - R"rs( + } + }}, + {"clearer_thunk", Thunk(ctx, field, "clear")}}, + R"rs( fn $clearer_thunk$(raw_msg: $pbi$::RawMessage); $getter$ )rs"); } -void RepeatedScalar::InThunkCc(Context field) const { - field.Emit({{"field", cpp::FieldName(&field.desc())}, - {"Scalar", cpp::PrimitiveTypeName(field.desc().cpp_type())}, - {"QualifiedMsg", - cpp::QualifiedClassName(field.desc().containing_type())}, - {"clearer_thunk", Thunk(field, "clear")}, - {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, - {"impls", - [&] { - field.Emit( - R"cc( - void $clearer_thunk$($QualifiedMsg$* msg) { - msg->clear_$field$(); - } - google::protobuf::RepeatedField<$Scalar$>* $getter_mut_thunk$($QualifiedMsg$* msg) { - return msg->mutable_$field$(); - } - const google::protobuf::RepeatedField<$Scalar$>& $getter_thunk$($QualifiedMsg$& msg) { - return msg.$field$(); - } - )cc"); - }}}, - "$impls$"); +void RepeatedScalar::InThunkCc(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit({{"field", cpp::FieldName(&field)}, + {"Scalar", cpp::PrimitiveTypeName(field.cpp_type())}, + {"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"getter_mut_thunk", Thunk(ctx, field, "get_mut")}, + {"impls", + [&] { + ctx.Emit( + R"cc( + void $clearer_thunk$($QualifiedMsg$* msg) { + msg->clear_$field$(); + } + google::protobuf::RepeatedField<$Scalar$>* $getter_mut_thunk$($QualifiedMsg$* msg) { + return msg->mutable_$field$(); + } + const google::protobuf::RepeatedField<$Scalar$>& $getter_thunk$($QualifiedMsg$& msg) { + return msg.$field$(); + } + )cc"); + }}}, + "$impls$"); } } // namespace rust diff --git a/src/google/protobuf/compiler/rust/accessors/singular_message.cc b/src/google/protobuf/compiler/rust/accessors/singular_message.cc index 210512a7eb..ca5dce246a 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_message.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_message.cc @@ -17,23 +17,23 @@ namespace protobuf { namespace compiler { namespace rust { -void SingularMessage::InMsgImpl(Context field) const { - Context d = field.WithDesc(field.desc().message_type()); +void SingularMessage::InMsgImpl(Context& ctx, + const FieldDescriptor& field) const { + auto& msg = *field.message_type(); + auto prefix = "crate::" + GetCrateRelativeQualifiedPath(ctx, msg); - auto prefix = "crate::" + GetCrateRelativeQualifiedPath(d); - - field.Emit( + ctx.Emit( { {"prefix", prefix}, - {"field", field.desc().name()}, - {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, - {"clearer_thunk", Thunk(field, "clear")}, + {"field", field.name()}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"getter_mut_thunk", Thunk(ctx, field, "get_mut")}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, { "view_body", [&] { - if (field.is_upb()) { - field.Emit({}, R"rs( + if (ctx.is_upb()) { + ctx.Emit({}, R"rs( let submsg = unsafe { $getter_thunk$(self.inner.msg) }; // For upb, getters return null if the field is unset, so we need // to check for null and return the default instance manually. @@ -46,7 +46,7 @@ void SingularMessage::InMsgImpl(Context field) const { } )rs"); } else { - field.Emit({}, R"rs( + ctx.Emit({}, R"rs( // For C++ kernel, getters automatically return the // default_instance if the field is unset. let submsg = unsafe { $getter_thunk$(self.inner.msg) }; @@ -57,15 +57,15 @@ void SingularMessage::InMsgImpl(Context field) const { }, {"submessage_mut", [&] { - if (field.is_upb()) { - field.Emit({}, R"rs( + if (ctx.is_upb()) { + ctx.Emit({}, R"rs( let submsg = unsafe { $getter_mut_thunk$(self.inner.msg, self.inner.arena.raw()) }; $prefix$Mut::new($pbi$::Private, &mut self.inner, submsg) )rs"); } else { - field.Emit({}, R"rs( + ctx.Emit({}, R"rs( let submsg = unsafe { $getter_mut_thunk$(self.inner.msg) }; $prefix$Mut::new($pbi$::Private, &mut self.inner, submsg) )rs"); @@ -87,21 +87,22 @@ void SingularMessage::InMsgImpl(Context field) const { )rs"); } -void SingularMessage::InExternC(Context field) const { - field.Emit( +void SingularMessage::InExternC(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit( { - {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, - {"clearer_thunk", Thunk(field, "clear")}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"getter_mut_thunk", Thunk(ctx, field, "get_mut")}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, {"getter_mut", [&] { - if (field.is_cpp()) { - field.Emit( + if (ctx.is_cpp()) { + ctx.Emit( R"rs( fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawMessage;)rs"); } else { - field.Emit( + ctx.Emit( R"rs(fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage, arena: $pbi$::RawArena) -> $pbi$::RawMessage;)rs"); @@ -109,13 +110,13 @@ void SingularMessage::InExternC(Context field) const { }}, {"ReturnType", [&] { - if (field.is_cpp()) { + if (ctx.is_cpp()) { // guaranteed to have a nonnull submsg for the cpp kernel - field.Emit({}, "$pbi$::RawMessage;"); + ctx.Emit({}, "$pbi$::RawMessage;"); } else { // upb kernel may return NULL for a submsg, we can detect this // in terra rust if the option returned is None - field.Emit({}, "Option<$pbi$::RawMessage>;"); + ctx.Emit({}, "Option<$pbi$::RawMessage>;"); } }}, }, @@ -126,22 +127,22 @@ void SingularMessage::InExternC(Context field) const { )rs"); } -void SingularMessage::InThunkCc(Context field) const { - field.Emit({{"QualifiedMsg", - cpp::QualifiedClassName(field.desc().containing_type())}, - {"getter_thunk", Thunk(field, "get")}, - {"getter_mut_thunk", Thunk(field, "get_mut")}, - {"clearer_thunk", Thunk(field, "clear")}, - {"field", cpp::FieldName(&field.desc())}}, - R"cc( - const void* $getter_thunk$($QualifiedMsg$* msg) { - return static_cast(&msg->$field$()); - } - void* $getter_mut_thunk$($QualifiedMsg$* msg) { - return static_cast(msg->mutable_$field$()); - } - void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } - )cc"); +void SingularMessage::InThunkCc(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit({{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"getter_mut_thunk", Thunk(ctx, field, "get_mut")}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, + {"field", cpp::FieldName(&field)}}, + R"cc( + const void* $getter_thunk$($QualifiedMsg$* msg) { + return static_cast(&msg->$field$()); + } + void* $getter_mut_thunk$($QualifiedMsg$* msg) { + return static_cast(msg->mutable_$field$()); + } + void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } + )cc"); } } // namespace rust diff --git a/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc b/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc index bf4d2c2d6a..48aa4dbfca 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc @@ -18,16 +18,17 @@ namespace protobuf { namespace compiler { namespace rust { -void SingularScalar::InMsgImpl(Context field) const { - field.Emit( +void SingularScalar::InMsgImpl(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit( { - {"field", field.desc().name()}, - {"Scalar", PrimitiveRsTypeName(field.desc())}, - {"hazzer_thunk", Thunk(field, "has")}, + {"field", field.name()}, + {"Scalar", PrimitiveRsTypeName(field)}, + {"hazzer_thunk", Thunk(ctx, field, "has")}, {"default_value", DefaultValue(field)}, {"getter", [&] { - field.Emit({}, R"rs( + ctx.Emit({}, R"rs( pub fn r#$field$(&self) -> $Scalar$ { unsafe { $getter_thunk$(self.inner.msg) } } @@ -35,9 +36,9 @@ void SingularScalar::InMsgImpl(Context field) const { }}, {"getter_opt", [&] { - if (!field.desc().is_optional()) return; - if (!field.desc().has_presence()) return; - field.Emit({}, R"rs( + if (!field.is_optional()) return; + if (!field.has_presence()) return; + ctx.Emit({}, R"rs( pub fn r#$field$_opt(&self) -> $pb$::Optional<$Scalar$> { if !unsafe { $hazzer_thunk$(self.inner.msg) } { return $pb$::Optional::Unset($default_value$); @@ -47,13 +48,13 @@ void SingularScalar::InMsgImpl(Context field) const { } )rs"); }}, - {"getter_thunk", Thunk(field, "get")}, - {"setter_thunk", Thunk(field, "set")}, - {"clearer_thunk", Thunk(field, "clear")}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"setter_thunk", Thunk(ctx, field, "set")}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, {"field_mutator_getter", [&] { - if (field.desc().has_presence()) { - field.Emit({}, R"rs( + if (field.has_presence()) { + ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) -> $pb$::FieldEntry<'_, $Scalar$> { static VTABLE: $pbi$::PrimitiveOptionalMutVTable<$Scalar$> = $pbi$::PrimitiveOptionalMutVTable::new( @@ -76,7 +77,7 @@ void SingularScalar::InMsgImpl(Context field) const { } )rs"); } else { - field.Emit({}, R"rs( + ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) -> $pb$::Mut<'_, $Scalar$> { static VTABLE: $pbi$::PrimitiveVTable<$Scalar$> = $pbi$::PrimitiveVTable::new( @@ -114,56 +115,57 @@ void SingularScalar::InMsgImpl(Context field) const { )rs"); } -void SingularScalar::InExternC(Context field) const { - field.Emit({{"Scalar", PrimitiveRsTypeName(field.desc())}, - {"hazzer_thunk", Thunk(field, "has")}, - {"getter_thunk", Thunk(field, "get")}, - {"setter_thunk", Thunk(field, "set")}, - {"clearer_thunk", Thunk(field, "clear")}, - {"hazzer_and_clearer", - [&] { - if (field.desc().has_presence()) { - field.Emit( - R"rs( +void SingularScalar::InExternC(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit({{"Scalar", PrimitiveRsTypeName(field)}, + {"hazzer_thunk", Thunk(ctx, field, "has")}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"setter_thunk", Thunk(ctx, field, "set")}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, + {"hazzer_and_clearer", + [&] { + if (field.has_presence()) { + ctx.Emit( + R"rs( fn $hazzer_thunk$(raw_msg: $pbi$::RawMessage) -> bool; fn $clearer_thunk$(raw_msg: $pbi$::RawMessage); )rs"); - } - }}}, - R"rs( + } + }}}, + R"rs( $hazzer_and_clearer$ fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $Scalar$; fn $setter_thunk$(raw_msg: $pbi$::RawMessage, val: $Scalar$); )rs"); } -void SingularScalar::InThunkCc(Context field) const { - field.Emit({{"field", cpp::FieldName(&field.desc())}, - {"Scalar", cpp::PrimitiveTypeName(field.desc().cpp_type())}, - {"QualifiedMsg", - cpp::QualifiedClassName(field.desc().containing_type())}, - {"hazzer_thunk", Thunk(field, "has")}, - {"getter_thunk", Thunk(field, "get")}, - {"setter_thunk", Thunk(field, "set")}, - {"clearer_thunk", Thunk(field, "clear")}, - {"hazzer_and_clearer", - [&] { - if (field.desc().has_presence()) { - field.Emit(R"cc( - bool $hazzer_thunk$($QualifiedMsg$* msg) { - return msg->has_$field$(); - } - void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } - )cc"); - } - }}}, - R"cc( - $hazzer_and_clearer$; - $Scalar$ $getter_thunk$($QualifiedMsg$* msg) { return msg->$field$(); } - void $setter_thunk$($QualifiedMsg$* msg, $Scalar$ val) { - msg->set_$field$(val); +void SingularScalar::InThunkCc(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit({{"field", cpp::FieldName(&field)}, + {"Scalar", cpp::PrimitiveTypeName(field.cpp_type())}, + {"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())}, + {"hazzer_thunk", Thunk(ctx, field, "has")}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"setter_thunk", Thunk(ctx, field, "set")}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, + {"hazzer_and_clearer", + [&] { + if (field.has_presence()) { + ctx.Emit(R"cc( + bool $hazzer_thunk$($QualifiedMsg$* msg) { + return msg->has_$field$(); + } + void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } + )cc"); } - )cc"); + }}}, + R"cc( + $hazzer_and_clearer$; + $Scalar$ $getter_thunk$($QualifiedMsg$* msg) { return msg->$field$(); } + void $setter_thunk$($QualifiedMsg$* msg, $Scalar$ val) { + msg->set_$field$(val); + } + )cc"); } } // namespace rust diff --git a/src/google/protobuf/compiler/rust/accessors/singular_string.cc b/src/google/protobuf/compiler/rust/accessors/singular_string.cc index f3ae1e2f49..c70380c67c 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_string.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_string.cc @@ -20,24 +20,25 @@ namespace protobuf { namespace compiler { namespace rust { -void SingularString::InMsgImpl(Context field) const { - std::string hazzer_thunk = Thunk(field, "has"); - std::string getter_thunk = Thunk(field, "get"); - std::string setter_thunk = Thunk(field, "set"); - std::string proxied_type = PrimitiveRsTypeName(field.desc()); +void SingularString::InMsgImpl(Context& ctx, + const FieldDescriptor& field) const { + std::string hazzer_thunk = Thunk(ctx, field, "has"); + std::string getter_thunk = Thunk(ctx, field, "get"); + std::string setter_thunk = Thunk(ctx, field, "set"); + std::string proxied_type = PrimitiveRsTypeName(field); auto transform_view = [&] { - if (field.desc().type() == FieldDescriptor::TYPE_STRING) { - field.Emit(R"rs( + if (field.type() == FieldDescriptor::TYPE_STRING) { + ctx.Emit(R"rs( // SAFETY: The runtime doesn't require ProtoStr to be UTF-8. unsafe { $pb$::ProtoStr::from_utf8_unchecked(view) } )rs"); } else { - field.Emit("view"); + ctx.Emit("view"); } }; - field.Emit( + ctx.Emit( { - {"field", field.desc().name()}, + {"field", field.name()}, {"hazzer_thunk", hazzer_thunk}, {"getter_thunk", getter_thunk}, {"setter_thunk", setter_thunk}, @@ -45,12 +46,12 @@ void SingularString::InMsgImpl(Context field) const { {"transform_view", transform_view}, {"field_optional_getter", [&] { - if (!field.desc().is_optional()) return; - if (!field.desc().has_presence()) return; - field.Emit({{"hazzer_thunk", hazzer_thunk}, - {"getter_thunk", getter_thunk}, - {"transform_view", transform_view}}, - R"rs( + if (!field.is_optional()) return; + if (!field.has_presence()) return; + ctx.Emit({{"hazzer_thunk", hazzer_thunk}, + {"getter_thunk", getter_thunk}, + {"transform_view", transform_view}}, + R"rs( pub fn $field$_opt(&self) -> $pb$::Optional<&$proxied_type$> { let view = unsafe { $getter_thunk$(self.inner.msg).as_ref() }; $pb$::Optional::new( @@ -62,30 +63,29 @@ void SingularString::InMsgImpl(Context field) const { }}, {"field_mutator_getter", [&] { - if (field.desc().has_presence()) { - field.Emit( + if (field.has_presence()) { + ctx.Emit( { - {"field", field.desc().name()}, + {"field", field.name()}, {"proxied_type", proxied_type}, {"default_val", DefaultValue(field)}, {"view_type", proxied_type}, {"transform_field_entry", [&] { - if (field.desc().type() == - FieldDescriptor::TYPE_STRING) { - field.Emit(R"rs( + if (field.type() == FieldDescriptor::TYPE_STRING) { + ctx.Emit(R"rs( $pb$::ProtoStrMut::field_entry_from_bytes( $pbi$::Private, out ) )rs"); } else { - field.Emit("out"); + ctx.Emit("out"); } }}, {"hazzer_thunk", hazzer_thunk}, {"getter_thunk", getter_thunk}, {"setter_thunk", setter_thunk}, - {"clearer_thunk", Thunk(field, "clear")}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, }, R"rs( pub fn $field$_mut(&mut self) -> $pb$::FieldEntry<'_, $proxied_type$> { @@ -112,11 +112,11 @@ void SingularString::InMsgImpl(Context field) const { } )rs"); } else { - field.Emit({{"field", field.desc().name()}, - {"proxied_type", proxied_type}, - {"getter_thunk", getter_thunk}, - {"setter_thunk", setter_thunk}}, - R"rs( + ctx.Emit({{"field", field.name()}, + {"proxied_type", proxied_type}, + {"getter_thunk", getter_thunk}, + {"setter_thunk", setter_thunk}}, + R"rs( pub fn $field$_mut(&mut self) -> $pb$::Mut<'_, $proxied_type$> { static VTABLE: $pbi$::BytesMutVTable = unsafe { $pbi$::BytesMutVTable::new( @@ -152,20 +152,21 @@ void SingularString::InMsgImpl(Context field) const { )rs"); } -void SingularString::InExternC(Context field) const { - field.Emit({{"hazzer_thunk", Thunk(field, "has")}, - {"getter_thunk", Thunk(field, "get")}, - {"setter_thunk", Thunk(field, "set")}, - {"clearer_thunk", Thunk(field, "clear")}, - {"hazzer", - [&] { - if (field.desc().has_presence()) { - field.Emit(R"rs( +void SingularString::InExternC(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit({{"hazzer_thunk", Thunk(ctx, field, "has")}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"setter_thunk", Thunk(ctx, field, "set")}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, + {"hazzer", + [&] { + if (field.has_presence()) { + ctx.Emit(R"rs( fn $hazzer_thunk$(raw_msg: $pbi$::RawMessage) -> bool; )rs"); - } - }}}, - R"rs( + } + }}}, + R"rs( $hazzer$ fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::PtrAndLen; fn $setter_thunk$(raw_msg: $pbi$::RawMessage, val: $pbi$::PtrAndLen); @@ -173,35 +174,35 @@ void SingularString::InExternC(Context field) const { )rs"); } -void SingularString::InThunkCc(Context field) const { - field.Emit({{"field", cpp::FieldName(&field.desc())}, - {"QualifiedMsg", - cpp::QualifiedClassName(field.desc().containing_type())}, - {"hazzer_thunk", Thunk(field, "has")}, - {"getter_thunk", Thunk(field, "get")}, - {"setter_thunk", Thunk(field, "set")}, - {"clearer_thunk", Thunk(field, "clear")}, - {"hazzer", - [&] { - if (field.desc().has_presence()) { - field.Emit(R"cc( - bool $hazzer_thunk$($QualifiedMsg$* msg) { - return msg->has_$field$(); - } - void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } - )cc"); - } - }}}, - R"cc( - $hazzer$; - ::google::protobuf::rust_internal::PtrAndLen $getter_thunk$($QualifiedMsg$* msg) { - absl::string_view val = msg->$field$(); - return ::google::protobuf::rust_internal::PtrAndLen(val.data(), val.size()); - } - void $setter_thunk$($QualifiedMsg$* msg, ::google::protobuf::rust_internal::PtrAndLen s) { - msg->set_$field$(absl::string_view(s.ptr, s.len)); +void SingularString::InThunkCc(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit({{"field", cpp::FieldName(&field)}, + {"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())}, + {"hazzer_thunk", Thunk(ctx, field, "has")}, + {"getter_thunk", Thunk(ctx, field, "get")}, + {"setter_thunk", Thunk(ctx, field, "set")}, + {"clearer_thunk", Thunk(ctx, field, "clear")}, + {"hazzer", + [&] { + if (field.has_presence()) { + ctx.Emit(R"cc( + bool $hazzer_thunk$($QualifiedMsg$* msg) { + return msg->has_$field$(); + } + void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } + )cc"); } - )cc"); + }}}, + R"cc( + $hazzer$; + ::google::protobuf::rust_internal::PtrAndLen $getter_thunk$($QualifiedMsg$* msg) { + absl::string_view val = msg->$field$(); + return ::google::protobuf::rust_internal::PtrAndLen(val.data(), val.size()); + } + void $setter_thunk$($QualifiedMsg$* msg, ::google::protobuf::rust_internal::PtrAndLen s) { + msg->set_$field$(absl::string_view(s.ptr, s.len)); + } + )cc"); } } // namespace rust diff --git a/src/google/protobuf/compiler/rust/accessors/unsupported_field.cc b/src/google/protobuf/compiler/rust/accessors/unsupported_field.cc index e6bce02171..591298a2bc 100644 --- a/src/google/protobuf/compiler/rust/accessors/unsupported_field.cc +++ b/src/google/protobuf/compiler/rust/accessors/unsupported_field.cc @@ -15,11 +15,12 @@ namespace protobuf { namespace compiler { namespace rust { -void UnsupportedField::InMsgImpl(Context field) const { - field.Emit({{"reason", reason_}}, R"rs( +void UnsupportedField::InMsgImpl(Context& ctx, + const FieldDescriptor& field) const { + ctx.Emit({{"reason", reason_}}, R"rs( // Unsupported! :( Reason: $reason$ )rs"); - field.printer().PrintRaw("\n"); + ctx.printer().PrintRaw("\n"); } } // namespace rust diff --git a/src/google/protobuf/compiler/rust/context.cc b/src/google/protobuf/compiler/rust/context.cc index 68d4c9fb2a..f64e88aab0 100644 --- a/src/google/protobuf/compiler/rust/context.cc +++ b/src/google/protobuf/compiler/rust/context.cc @@ -68,13 +68,12 @@ absl::StatusOr Options::Parse(absl::string_view param) { return opts; } -bool IsInCurrentlyGeneratingCrate(Context file) { - return file.generator_context().is_file_in_current_crate(&file.desc()); +bool IsInCurrentlyGeneratingCrate(Context& ctx, const FileDescriptor& file) { + return ctx.generator_context().is_file_in_current_crate(file); } -bool IsInCurrentlyGeneratingCrate(Context message) { - return message.generator_context().is_file_in_current_crate( - message.desc().file()); +bool IsInCurrentlyGeneratingCrate(Context& ctx, const Descriptor& message) { + return IsInCurrentlyGeneratingCrate(ctx, *message.file()); } } // namespace rust diff --git a/src/google/protobuf/compiler/rust/context.h b/src/google/protobuf/compiler/rust/context.h index affe8aee1b..1539e2e8f6 100644 --- a/src/google/protobuf/compiler/rust/context.h +++ b/src/google/protobuf/compiler/rust/context.h @@ -53,14 +53,14 @@ class RustGeneratorContext { const std::vector* files_in_current_crate) : files_in_current_crate_(*files_in_current_crate) {} - const FileDescriptor* primary_file() const { - return files_in_current_crate_.front(); + const FileDescriptor& primary_file() const { + return *files_in_current_crate_.front(); } - bool is_file_in_current_crate(const FileDescriptor* f) const { + bool is_file_in_current_crate(const FileDescriptor& f) const { return std::find(files_in_current_crate_.begin(), files_in_current_crate_.end(), - f) != files_in_current_crate_.end(); + &f) != files_in_current_crate_.end(); } private: @@ -68,26 +68,20 @@ class RustGeneratorContext { }; // A context for generating a particular kind of definition. -// This type acts as an options struct (as in go/totw/173) for most of the -// generator. -// -// `Descriptor` is the type of a descriptor.h class relevant for the current -// context. -template class Context { public: - Context(const Options* opts, const Descriptor* desc, + Context(const Options* opts, const RustGeneratorContext* rust_generator_context, io::Printer* printer) : opts_(opts), - desc_(desc), rust_generator_context_(rust_generator_context), printer_(printer) {} - Context(const Context&) = default; - Context& operator=(const Context&) = default; + Context(const Context&) = delete; + Context& operator=(const Context&) = delete; + Context(Context&&) = default; + Context& operator=(Context&&) = default; - const Descriptor& desc() const { return *desc_; } const Options& opts() const { return *opts_; } const RustGeneratorContext& generator_context() const { return *rust_generator_context_; @@ -99,19 +93,8 @@ class Context { // NOTE: prefer ctx.Emit() over ctx.printer().Emit(); io::Printer& printer() const { return *printer_; } - // Creates a new context over a different descriptor. - template - Context WithDesc(const D& desc) const { - return Context(opts_, &desc, rust_generator_context_, printer_); - } - - template - Context WithDesc(const D* desc) const { - return Context(opts_, desc, rust_generator_context_, printer_); - } - Context WithPrinter(io::Printer* printer) const { - return Context(opts_, desc_, rust_generator_context_, printer); + return Context(opts_, rust_generator_context_, printer); } // Forwards to Emit(), which will likely be called all the time. @@ -128,13 +111,12 @@ class Context { private: const Options* opts_; - const Descriptor* desc_; const RustGeneratorContext* rust_generator_context_; io::Printer* printer_; }; -bool IsInCurrentlyGeneratingCrate(Context file); -bool IsInCurrentlyGeneratingCrate(Context message); +bool IsInCurrentlyGeneratingCrate(Context& ctx, const FileDescriptor& file); +bool IsInCurrentlyGeneratingCrate(Context& ctx, const Descriptor& message); } // namespace rust } // namespace compiler diff --git a/src/google/protobuf/compiler/rust/generator.cc b/src/google/protobuf/compiler/rust/generator.cc index 3465672232..04d5ff6f32 100644 --- a/src/google/protobuf/compiler/rust/generator.cc +++ b/src/google/protobuf/compiler/rust/generator.cc @@ -48,12 +48,11 @@ namespace { // pub mod submodule { // pub mod separator { // ``` -void EmitOpeningOfPackageModules(absl::string_view pkg, - Context file) { +void EmitOpeningOfPackageModules(Context& ctx, absl::string_view pkg) { if (pkg.empty()) return; for (absl::string_view segment : absl::StrSplit(pkg, '.')) { - file.Emit({{"segment", segment}}, - R"rs( + ctx.Emit({{"segment", segment}}, + R"rs( pub mod $segment$ { )rs"); } @@ -70,14 +69,13 @@ void EmitOpeningOfPackageModules(absl::string_view pkg, // } // mod uses // } // mod package // ``` -void EmitClosingOfPackageModules(absl::string_view pkg, - Context file) { +void EmitClosingOfPackageModules(Context& ctx, absl::string_view pkg) { if (pkg.empty()) return; std::vector segments = absl::StrSplit(pkg, '.'); absl::c_reverse(segments); for (absl::string_view segment : segments) { - file.Emit({{"segment", segment}}, R"rs( + ctx.Emit({{"segment", segment}}, R"rs( } // mod $segment$ )rs"); } @@ -87,14 +85,13 @@ void EmitClosingOfPackageModules(absl::string_view pkg, // `non_primary_src` into the `primary_file`. // // `non_primary_src` has to be a non-primary src of the current `proto_library`. -void EmitPubUseOfOwnMessages(Context& primary_file, - const Context& non_primary_src) { - for (int i = 0; i < non_primary_src.desc().message_type_count(); ++i) { - auto msg = primary_file.WithDesc(non_primary_src.desc().message_type(i)); - auto mod = RustInternalModuleName(non_primary_src); - auto name = msg.desc().name(); - primary_file.Emit({{"mod", mod}, {"Msg", name}}, - R"rs( +void EmitPubUseOfOwnMessages(Context& ctx, const FileDescriptor& primary_file, + const FileDescriptor& non_primary_src) { + for (int i = 0; i < non_primary_src.message_type_count(); ++i) { + auto& msg = *non_primary_src.message_type(i); + auto mod = RustInternalModuleName(ctx, non_primary_src); + ctx.Emit({{"mod", mod}, {"Msg", msg.name()}}, + R"rs( pub use crate::$mod$::$Msg$; // TODO Address use for imported crates pub use crate::$mod$::$Msg$View; @@ -109,14 +106,15 @@ void EmitPubUseOfOwnMessages(Context& primary_file, // // `dep` is a primary src of a dependency of the current `proto_library`. // TODO: Add support for public import of non-primary srcs of deps. -void EmitPubUseForImportedMessages(Context& primary_file, - const Context& dep) { - std::string crate_name = GetCrateName(dep); - for (int i = 0; i < dep.desc().message_type_count(); ++i) { - auto msg = primary_file.WithDesc(dep.desc().message_type(i)); - auto path = GetCrateRelativeQualifiedPath(msg); - primary_file.Emit({{"crate", crate_name}, {"pkg::Msg", path}}, - R"rs( +void EmitPubUseForImportedMessages(Context& ctx, + const FileDescriptor& primary_file, + const FileDescriptor& dep) { + std::string crate_name = GetCrateName(ctx, dep); + for (int i = 0; i < dep.message_type_count(); ++i) { + auto& msg = *dep.message_type(i); + auto path = GetCrateRelativeQualifiedPath(ctx, msg); + ctx.Emit({{"crate", crate_name}, {"pkg::Msg", path}}, + R"rs( pub use $crate$::$pkg::Msg$; pub use $crate$::$pkg::Msg$View; )rs"); @@ -124,9 +122,9 @@ void EmitPubUseForImportedMessages(Context& primary_file, } // Emits all public imports of the current file -void EmitPublicImports(Context& primary_file) { - for (int i = 0; i < primary_file.desc().public_dependency_count(); ++i) { - auto dep_file = primary_file.desc().public_dependency(i); +void EmitPublicImports(Context& ctx, const FileDescriptor& primary_file) { + for (int i = 0; i < primary_file.public_dependency_count(); ++i) { + auto& dep_file = *primary_file.public_dependency(i); // If the publicly imported file is a src of the current `proto_library` // we don't need to emit `pub use` here, we already do it for all srcs in // RustGenerator::Generate. In other words, all srcs are implicitly publicly @@ -134,30 +132,29 @@ void EmitPublicImports(Context& primary_file) { // TODO: Handle the case where a non-primary src with the same // declared package as the primary src publicly imports a file that the // primary doesn't. - auto dep = primary_file.WithDesc(dep_file); - if (IsInCurrentlyGeneratingCrate(dep)) { + if (IsInCurrentlyGeneratingCrate(ctx, dep_file)) { return; } - EmitPubUseForImportedMessages(primary_file, dep); + EmitPubUseForImportedMessages(ctx, primary_file, dep_file); } } // Emits submodule declarations so `rustc` can find non primary sources from the // primary file. void DeclareSubmodulesForNonPrimarySrcs( - Context& primary_file, - absl::Span> non_primary_srcs) { - std::string primary_file_path = GetRsFile(primary_file); + Context& ctx, const FileDescriptor& primary_file, + absl::Span non_primary_srcs) { + std::string primary_file_path = GetRsFile(ctx, primary_file); RelativePath primary_relpath(primary_file_path); - for (const auto& non_primary_src : non_primary_srcs) { - std::string non_primary_file_path = GetRsFile(non_primary_src); + for (const FileDescriptor* non_primary_src : non_primary_srcs) { + std::string non_primary_file_path = GetRsFile(ctx, *non_primary_src); std::string relative_mod_path = primary_relpath.Relative(RelativePath(non_primary_file_path)); - primary_file.Emit({{"file_path", relative_mod_path}, - {"foo", primary_file_path}, - {"bar", non_primary_file_path}, - {"mod_name", RustInternalModuleName(non_primary_src)}}, - R"rs( + ctx.Emit({{"file_path", relative_mod_path}, + {"foo", primary_file_path}, + {"bar", non_primary_file_path}, + {"mod_name", RustInternalModuleName(ctx, *non_primary_src)}}, + R"rs( #[path="$file_path$"] pub mod $mod_name$; )rs"); @@ -169,33 +166,32 @@ void DeclareSubmodulesForNonPrimarySrcs( // // Returns the non-primary sources that should be reexported from the package of // the primary file. -std::vector*> ReexportMessagesFromSubmodules( - Context& primary_file, - absl::Span> non_primary_srcs) { - absl::btree_map*>> +std::vector ReexportMessagesFromSubmodules( + Context& ctx, const FileDescriptor& primary_file, + absl::Span non_primary_srcs) { + absl::btree_map> packages; - for (const Context& ctx : non_primary_srcs) { - packages[ctx.desc().package()].push_back(&ctx); + for (const FileDescriptor* file : non_primary_srcs) { + packages[file->package()].push_back(file); } for (const auto& pair : packages) { // We will deal with messages for the package of the primary file later. auto fds = pair.second; - absl::string_view package = fds[0]->desc().package(); - if (package == primary_file.desc().package()) continue; + absl::string_view package = fds[0]->package(); + if (package == primary_file.package()) continue; - EmitOpeningOfPackageModules(package, primary_file); - for (const Context* c : fds) { - EmitPubUseOfOwnMessages(primary_file, *c); + EmitOpeningOfPackageModules(ctx, package); + for (const FileDescriptor* c : fds) { + EmitPubUseOfOwnMessages(ctx, primary_file, *c); } - EmitClosingOfPackageModules(package, primary_file); + EmitClosingOfPackageModules(ctx, package); } - return packages[primary_file.desc().package()]; + return packages[primary_file.package()]; } } // namespace -bool RustGenerator::Generate(const FileDescriptor* file_desc, +bool RustGenerator::Generate(const FileDescriptor* file, const std::string& parameter, GeneratorContext* generator_context, std::string* error) const { @@ -210,15 +206,15 @@ bool RustGenerator::Generate(const FileDescriptor* file_desc, RustGeneratorContext rust_generator_context(&files_in_current_crate); - Context file(&*opts, file_desc, &rust_generator_context, - nullptr); + Context ctx_without_printer(&*opts, &rust_generator_context, nullptr); - auto outfile = absl::WrapUnique(generator_context->Open(GetRsFile(file))); + auto outfile = absl::WrapUnique( + generator_context->Open(GetRsFile(ctx_without_printer, *file))); io::Printer printer(outfile.get()); - file = file.WithPrinter(&printer); + Context ctx = ctx_without_printer.WithPrinter(&printer); // Convenience shorthands for common symbols. - auto v = file.printer().WithVars({ + auto v = ctx.printer().WithVars({ {"std", "::__std"}, {"pb", "::__pb"}, {"pbi", "::__pb::__internal"}, @@ -227,67 +223,66 @@ bool RustGenerator::Generate(const FileDescriptor* file_desc, {"Phantom", "::__std::marker::PhantomData"}, }); - file.Emit({{"kernel", KernelRsName(file.opts().kernel)}}, R"rs( + ctx.Emit({{"kernel", KernelRsName(ctx.opts().kernel)}}, R"rs( extern crate protobuf_$kernel$ as __pb; extern crate std as __std; )rs"); - std::vector> file_contexts; + std::vector file_contexts; for (const FileDescriptor* f : files_in_current_crate) { - file_contexts.push_back(file.WithDesc(*f)); + file_contexts.push_back(f); } // Generating the primary file? - if (file_desc == rust_generator_context.primary_file()) { + if (file == &rust_generator_context.primary_file()) { auto non_primary_srcs = absl::MakeConstSpan(file_contexts).subspan(1); - DeclareSubmodulesForNonPrimarySrcs(file, non_primary_srcs); + DeclareSubmodulesForNonPrimarySrcs(ctx, *file, non_primary_srcs); - std::vector*> - non_primary_srcs_in_primary_package = - ReexportMessagesFromSubmodules(file, non_primary_srcs); + std::vector non_primary_srcs_in_primary_package = + ReexportMessagesFromSubmodules(ctx, *file, non_primary_srcs); - EmitOpeningOfPackageModules(file.desc().package(), file); + EmitOpeningOfPackageModules(ctx, file->package()); - for (const Context* non_primary_file : + for (const FileDescriptor* non_primary_file : non_primary_srcs_in_primary_package) { - EmitPubUseOfOwnMessages(file, *non_primary_file); + EmitPubUseOfOwnMessages(ctx, *file, *non_primary_file); } } - EmitPublicImports(file); + EmitPublicImports(ctx, *file); std::unique_ptr thunks_cc; std::unique_ptr thunks_printer; - if (file.is_cpp()) { - thunks_cc.reset(generator_context->Open(GetThunkCcFile(file))); + if (ctx.is_cpp()) { + thunks_cc.reset(generator_context->Open(GetThunkCcFile(ctx, *file))); thunks_printer = std::make_unique(thunks_cc.get()); - thunks_printer->Emit({{"proto_h", GetHeaderFile(file)}}, + thunks_printer->Emit({{"proto_h", GetHeaderFile(ctx, *file)}}, R"cc( #include "$proto_h$" #include "google/protobuf/rust/cpp_kernel/cpp_api.h" )cc"); } - for (int i = 0; i < file.desc().message_type_count(); ++i) { - auto msg = file.WithDesc(file.desc().message_type(i)); + for (int i = 0; i < file->message_type_count(); ++i) { + auto& msg = *file->message_type(i); - GenerateRs(msg); - msg.printer().PrintRaw("\n"); + GenerateRs(ctx, msg); + ctx.printer().PrintRaw("\n"); - if (file.is_cpp()) { - auto thunks_msg = msg.WithPrinter(thunks_printer.get()); + if (ctx.is_cpp()) { + auto thunks_ctx = ctx.WithPrinter(thunks_printer.get()); - thunks_msg.Emit({{"Msg", msg.desc().full_name()}}, R"cc( + thunks_ctx.Emit({{"Msg", msg.full_name()}}, R"cc( // $Msg$ )cc"); - GenerateThunksCc(thunks_msg); - thunks_msg.printer().PrintRaw("\n"); + GenerateThunksCc(thunks_ctx, msg); + thunks_ctx.printer().PrintRaw("\n"); } } - if (file_desc == files_in_current_crate.front()) { - EmitClosingOfPackageModules(file.desc().package(), file); + if (file == files_in_current_crate.front()) { + EmitClosingOfPackageModules(ctx, file->package()); } return true; } diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 256cf81ffd..6f5454d1a5 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -24,16 +24,16 @@ namespace compiler { namespace rust { namespace { -void MessageNew(Context msg) { - switch (msg.opts().kernel) { +void MessageNew(Context& ctx, const Descriptor& msg) { + switch (ctx.opts().kernel) { case Kernel::kCpp: - msg.Emit({{"new_thunk", Thunk(msg, "new")}}, R"rs( + ctx.Emit({{"new_thunk", Thunk(ctx, msg, "new")}}, R"rs( Self { inner: $pbr$::MessageInner { msg: unsafe { $new_thunk$() } } } )rs"); return; case Kernel::kUpb: - msg.Emit({{"new_thunk", Thunk(msg, "new")}}, R"rs( + ctx.Emit({{"new_thunk", Thunk(ctx, msg, "new")}}, R"rs( let arena = $pbr$::Arena::new(); Self { inner: $pbr$::MessageInner { @@ -48,16 +48,16 @@ void MessageNew(Context msg) { ABSL_LOG(FATAL) << "unreachable"; } -void MessageSerialize(Context msg) { - switch (msg.opts().kernel) { +void MessageSerialize(Context& ctx, const Descriptor& msg) { + switch (ctx.opts().kernel) { case Kernel::kCpp: - msg.Emit({{"serialize_thunk", Thunk(msg, "serialize")}}, R"rs( + ctx.Emit({{"serialize_thunk", Thunk(ctx, msg, "serialize")}}, R"rs( unsafe { $serialize_thunk$(self.inner.msg) } )rs"); return; case Kernel::kUpb: - msg.Emit({{"serialize_thunk", Thunk(msg, "serialize")}}, R"rs( + ctx.Emit({{"serialize_thunk", Thunk(ctx, msg, "serialize")}}, R"rs( let arena = $pbr$::Arena::new(); let mut len = 0; unsafe { @@ -71,12 +71,12 @@ void MessageSerialize(Context msg) { ABSL_LOG(FATAL) << "unreachable"; } -void MessageDeserialize(Context msg) { - switch (msg.opts().kernel) { +void MessageDeserialize(Context& ctx, const Descriptor& msg) { + switch (ctx.opts().kernel) { case Kernel::kCpp: - msg.Emit( + ctx.Emit( { - {"deserialize_thunk", Thunk(msg, "deserialize")}, + {"deserialize_thunk", Thunk(ctx, msg, "deserialize")}, }, R"rs( let success = unsafe { @@ -92,7 +92,7 @@ void MessageDeserialize(Context msg) { return; case Kernel::kUpb: - msg.Emit({{"deserialize_thunk", Thunk(msg, "parse")}}, R"rs( + ctx.Emit({{"deserialize_thunk", Thunk(ctx, msg, "parse")}}, R"rs( let arena = $pbr$::Arena::new(); let msg = unsafe { $deserialize_thunk$(data.as_ptr(), data.len(), arena.raw()) @@ -115,15 +115,15 @@ void MessageDeserialize(Context msg) { ABSL_LOG(FATAL) << "unreachable"; } -void MessageExterns(Context msg) { - switch (msg.opts().kernel) { +void MessageExterns(Context& ctx, const Descriptor& msg) { + switch (ctx.opts().kernel) { case Kernel::kCpp: - msg.Emit( + ctx.Emit( { - {"new_thunk", Thunk(msg, "new")}, - {"delete_thunk", Thunk(msg, "delete")}, - {"serialize_thunk", Thunk(msg, "serialize")}, - {"deserialize_thunk", Thunk(msg, "deserialize")}, + {"new_thunk", Thunk(ctx, msg, "new")}, + {"delete_thunk", Thunk(ctx, msg, "delete")}, + {"serialize_thunk", Thunk(ctx, msg, "serialize")}, + {"deserialize_thunk", Thunk(ctx, msg, "deserialize")}, }, R"rs( fn $new_thunk$() -> $pbi$::RawMessage; @@ -134,11 +134,11 @@ void MessageExterns(Context msg) { return; case Kernel::kUpb: - msg.Emit( + ctx.Emit( { - {"new_thunk", Thunk(msg, "new")}, - {"serialize_thunk", Thunk(msg, "serialize")}, - {"deserialize_thunk", Thunk(msg, "parse")}, + {"new_thunk", Thunk(ctx, msg, "new")}, + {"serialize_thunk", Thunk(ctx, msg, "serialize")}, + {"deserialize_thunk", Thunk(ctx, msg, "parse")}, }, R"rs( fn $new_thunk$(arena: $pbi$::RawArena) -> $pbi$::RawMessage; @@ -151,36 +151,37 @@ void MessageExterns(Context msg) { ABSL_LOG(FATAL) << "unreachable"; } -void MessageDrop(Context msg) { - if (msg.is_upb()) { +void MessageDrop(Context& ctx, const Descriptor& msg) { + if (ctx.is_upb()) { // Nothing to do here; drop glue (which will run drop(self.arena) // automatically) is sufficient. return; } - msg.Emit({{"delete_thunk", Thunk(msg, "delete")}}, R"rs( + ctx.Emit({{"delete_thunk", Thunk(ctx, msg, "delete")}}, R"rs( unsafe { $delete_thunk$(self.inner.msg); } )rs"); } -void GetterForViewOrMut(Context field, bool is_mut) { - auto fieldName = field.desc().name(); - auto fieldType = field.desc().type(); - auto getter_thunk = Thunk(field, "get"); - auto setter_thunk = Thunk(field, "set"); - auto clearer_thunk = Thunk(field, "clear"); +void GetterForViewOrMut(Context& ctx, const FieldDescriptor& field, + bool is_mut) { + auto fieldName = field.name(); + auto fieldType = field.type(); + auto getter_thunk = Thunk(ctx, field, "get"); + auto setter_thunk = Thunk(ctx, field, "set"); + auto clearer_thunk = Thunk(ctx, field, "clear"); // If we're dealing with a Mut, the getter must be supplied // self.inner.msg() whereas a View has to be supplied self.msg auto self = is_mut ? "self.inner.msg()" : "self.msg"; if (fieldType == FieldDescriptor::TYPE_MESSAGE) { - Context d = field.WithDesc(field.desc().message_type()); + const Descriptor& msg = *field.message_type(); // TODO: support messages which are defined in other crates. - if (!IsInCurrentlyGeneratingCrate(d)) { + if (!IsInCurrentlyGeneratingCrate(ctx, msg)) { return; } - auto prefix = "crate::" + GetCrateRelativeQualifiedPath(d); - field.Emit( + auto prefix = "crate::" + GetCrateRelativeQualifiedPath(ctx, msg); + ctx.Emit( { {"prefix", prefix}, {"field", fieldName}, @@ -190,8 +191,8 @@ void GetterForViewOrMut(Context field, bool is_mut) { { "view_body", [&] { - if (field.is_upb()) { - field.Emit({}, R"rs( + if (ctx.is_upb()) { + ctx.Emit({}, R"rs( let submsg = unsafe { $getter_thunk$($self$) }; match submsg { None => $prefix$View::new($pbi$::Private, @@ -200,7 +201,7 @@ void GetterForViewOrMut(Context field, bool is_mut) { } )rs"); } else { - field.Emit({}, R"rs( + ctx.Emit({}, R"rs( let submsg = unsafe { $getter_thunk$($self$) }; $prefix$View::new($pbi$::Private, submsg) )rs"); @@ -216,18 +217,18 @@ void GetterForViewOrMut(Context field, bool is_mut) { return; } - auto rsType = PrimitiveRsTypeName(field.desc()); + auto rsType = PrimitiveRsTypeName(field); if (fieldType == FieldDescriptor::TYPE_STRING || fieldType == FieldDescriptor::TYPE_BYTES) { - field.Emit({{"field", fieldName}, - {"self", self}, - {"getter_thunk", getter_thunk}, - {"setter_thunk", setter_thunk}, - {"RsType", rsType}, - {"maybe_mutator", - [&] { - if (is_mut) { - field.Emit({}, R"rs( + ctx.Emit({{"field", fieldName}, + {"self", self}, + {"getter_thunk", getter_thunk}, + {"setter_thunk", setter_thunk}, + {"RsType", rsType}, + {"maybe_mutator", + [&] { + if (is_mut) { + ctx.Emit({}, R"rs( pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> { static VTABLE: $pbi$::BytesMutVTable = $pbi$::BytesMutVTable::new( @@ -248,9 +249,9 @@ void GetterForViewOrMut(Context field, bool is_mut) { } } )rs"); - } - }}}, - R"rs( + } + }}}, + R"rs( pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> { let s = unsafe { $getter_thunk$($self$).as_ref() }; unsafe { __pb::ProtoStr::from_utf8_unchecked(s).into() } @@ -259,19 +260,19 @@ void GetterForViewOrMut(Context field, bool is_mut) { $maybe_mutator$ )rs"); } else { - field.Emit({{"field", fieldName}, - {"getter_thunk", getter_thunk}, - {"setter_thunk", setter_thunk}, - {"clearer_thunk", clearer_thunk}, - {"self", self}, - {"RsType", rsType}, - {"maybe_mutator", - [&] { - // TODO: once the rust public api is accessible, - // by tooling, ensure that this only appears for the - // mutational pathway - if (is_mut && fieldType) { - field.Emit({}, R"rs( + ctx.Emit({{"field", fieldName}, + {"getter_thunk", getter_thunk}, + {"setter_thunk", setter_thunk}, + {"clearer_thunk", clearer_thunk}, + {"self", self}, + {"RsType", rsType}, + {"maybe_mutator", + [&] { + // TODO: once the rust public api is accessible, + // by tooling, ensure that this only appears for the + // mutational pathway + if (is_mut && fieldType) { + ctx.Emit({}, R"rs( pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> { static VTABLE: $pbi$::PrimitiveVTable<$RsType$> = $pbi$::PrimitiveVTable::new( @@ -290,9 +291,9 @@ void GetterForViewOrMut(Context field, bool is_mut) { } } )rs"); - } - }}}, - R"rs( + } + }}}, + R"rs( pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> { unsafe { $getter_thunk$($self$) } } @@ -302,93 +303,89 @@ void GetterForViewOrMut(Context field, bool is_mut) { } } -void AccessorsForViewOrMut(Context msg, bool is_mut) { - for (int i = 0; i < msg.desc().field_count(); ++i) { - auto field = msg.WithDesc(*msg.desc().field(i)); - if (field.desc().is_repeated()) continue; +void AccessorsForViewOrMut(Context& ctx, const Descriptor& msg, bool is_mut) { + for (int i = 0; i < msg.field_count(); ++i) { + const FieldDescriptor& field = *msg.field(i); + if (field.is_repeated()) continue; // TODO - add cord support - if (field.desc().options().has_ctype()) continue; + if (field.options().has_ctype()) continue; // TODO - if (field.desc().type() == FieldDescriptor::TYPE_ENUM || - field.desc().type() == FieldDescriptor::TYPE_GROUP) + if (field.type() == FieldDescriptor::TYPE_ENUM || + field.type() == FieldDescriptor::TYPE_GROUP) continue; - GetterForViewOrMut(field, is_mut); - msg.printer().PrintRaw("\n"); + GetterForViewOrMut(ctx, field, is_mut); + ctx.printer().PrintRaw("\n"); } } } // namespace -void GenerateRs(Context msg) { - if (msg.desc().map_key() != nullptr) { - ABSL_LOG(WARNING) << "unsupported map field: " << msg.desc().full_name(); +void GenerateRs(Context& ctx, const Descriptor& msg) { + if (msg.map_key() != nullptr) { + ABSL_LOG(WARNING) << "unsupported map field: " << msg.full_name(); return; } - msg.Emit( - {{"Msg", msg.desc().name()}, - {"Msg::new", [&] { MessageNew(msg); }}, - {"Msg::serialize", [&] { MessageSerialize(msg); }}, - {"Msg::deserialize", [&] { MessageDeserialize(msg); }}, - {"Msg::drop", [&] { MessageDrop(msg); }}, - {"Msg_externs", [&] { MessageExterns(msg); }}, - {"accessor_fns", - [&] { - for (int i = 0; i < msg.desc().field_count(); ++i) { - auto field = msg.WithDesc(*msg.desc().field(i)); - msg.Emit({{"comment", FieldInfoComment(field)}}, R"rs( + ctx.Emit({{"Msg", msg.name()}, + {"Msg::new", [&] { MessageNew(ctx, msg); }}, + {"Msg::serialize", [&] { MessageSerialize(ctx, msg); }}, + {"Msg::deserialize", [&] { MessageDeserialize(ctx, msg); }}, + {"Msg::drop", [&] { MessageDrop(ctx, msg); }}, + {"Msg_externs", [&] { MessageExterns(ctx, msg); }}, + {"accessor_fns", + [&] { + for (int i = 0; i < msg.field_count(); ++i) { + auto& field = *msg.field(i); + ctx.Emit({{"comment", FieldInfoComment(ctx, field)}}, R"rs( // $comment$ )rs"); - GenerateAccessorMsgImpl(field); - msg.printer().PrintRaw("\n"); - } - }}, - {"oneof_accessor_fns", - [&] { - for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) { - GenerateOneofAccessors( - msg.WithDesc(*msg.desc().real_oneof_decl(i))); - msg.printer().PrintRaw("\n"); - } - }}, - {"accessor_externs", - [&] { - for (int i = 0; i < msg.desc().field_count(); ++i) { - GenerateAccessorExternC(msg.WithDesc(*msg.desc().field(i))); - msg.printer().PrintRaw("\n"); - } - }}, - {"oneof_externs", - [&] { - for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) { - GenerateOneofExternC(msg.WithDesc(*msg.desc().real_oneof_decl(i))); - msg.printer().PrintRaw("\n"); - } - }}, - {"nested_msgs", - [&] { - // If we have no nested types or oneofs, bail out without emitting - // an empty mod SomeMsg_. - if (msg.desc().nested_type_count() == 0 && - msg.desc().real_oneof_decl_count() == 0) { - return; - } - msg.Emit( - {{"Msg", msg.desc().name()}, - {"nested_msgs", - [&] { - for (int i = 0; i < msg.desc().nested_type_count(); ++i) { - auto nested_msg = msg.WithDesc(msg.desc().nested_type(i)); - GenerateRs(nested_msg); - } - }}, - {"oneofs", - [&] { - for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) { - GenerateOneofDefinition( - msg.WithDesc(*msg.desc().real_oneof_decl(i))); - } - }}}, - R"rs( + GenerateAccessorMsgImpl(ctx, field); + ctx.printer().PrintRaw("\n"); + } + }}, + {"oneof_accessor_fns", + [&] { + for (int i = 0; i < msg.real_oneof_decl_count(); ++i) { + GenerateOneofAccessors(ctx, *msg.real_oneof_decl(i)); + ctx.printer().PrintRaw("\n"); + } + }}, + {"accessor_externs", + [&] { + for (int i = 0; i < msg.field_count(); ++i) { + GenerateAccessorExternC(ctx, *msg.field(i)); + ctx.printer().PrintRaw("\n"); + } + }}, + {"oneof_externs", + [&] { + for (int i = 0; i < msg.real_oneof_decl_count(); ++i) { + GenerateOneofExternC(ctx, *msg.real_oneof_decl(i)); + ctx.printer().PrintRaw("\n"); + } + }}, + {"nested_msgs", + [&] { + // If we have no nested types or oneofs, bail out without + // emitting an empty mod SomeMsg_. + if (msg.nested_type_count() == 0 && + msg.real_oneof_decl_count() == 0) { + return; + } + ctx.Emit( + {{"Msg", msg.name()}, + {"nested_msgs", + [&] { + for (int i = 0; i < msg.nested_type_count(); ++i) { + GenerateRs(ctx, *msg.nested_type(i)); + } + }}, + {"oneofs", + [&] { + for (int i = 0; i < msg.real_oneof_decl_count(); ++i) { + GenerateOneofDefinition(ctx, *msg.real_oneof_decl(i)); + } + }}}, + R"rs( #[allow(non_snake_case)] pub mod $Msg$_ { $nested_msgs$ @@ -396,10 +393,12 @@ void GenerateRs(Context msg) { $oneofs$ } // mod $Msg$_ )rs"); - }}, - {"accessor_fns_for_views", [&] { AccessorsForViewOrMut(msg, false); }}, - {"accessor_fns_for_muts", [&] { AccessorsForViewOrMut(msg, true); }}}, - R"rs( + }}, + {"accessor_fns_for_views", + [&] { AccessorsForViewOrMut(ctx, msg, false); }}, + {"accessor_fns_for_muts", + [&] { AccessorsForViewOrMut(ctx, msg, true); }}}, + R"rs( #[allow(non_camel_case_types)] // TODO: Implement support for debug redaction #[derive(Debug)] @@ -542,9 +541,9 @@ void GenerateRs(Context msg) { $nested_msgs$ )rs"); - if (msg.is_cpp()) { - msg.printer().PrintRaw("\n"); - msg.Emit({{"Msg", msg.desc().name()}}, R"rs( + if (ctx.is_cpp()) { + ctx.printer().PrintRaw("\n"); + ctx.Emit({{"Msg", msg.name()}}, R"rs( impl $Msg$ { pub fn __unstable_wrap_cpp_grant_permission_to_break(msg: $pbi$::RawMessage) -> Self { Self { inner: $pbr$::MessageInner { msg } } @@ -558,39 +557,37 @@ void GenerateRs(Context msg) { } // Generates code for a particular message in `.pb.thunk.cc`. -void GenerateThunksCc(Context msg) { - ABSL_CHECK(msg.is_cpp()); - if (msg.desc().map_key() != nullptr) { - ABSL_LOG(WARNING) << "unsupported map field: " << msg.desc().full_name(); +void GenerateThunksCc(Context& ctx, const Descriptor& msg) { + ABSL_CHECK(ctx.is_cpp()); + if (msg.map_key() != nullptr) { + ABSL_LOG(WARNING) << "unsupported map field: " << msg.full_name(); return; } - msg.Emit( + ctx.Emit( {{"abi", "\"C\""}, // Workaround for syntax highlight bug in VSCode. - {"Msg", msg.desc().name()}, - {"QualifiedMsg", cpp::QualifiedClassName(&msg.desc())}, - {"new_thunk", Thunk(msg, "new")}, - {"delete_thunk", Thunk(msg, "delete")}, - {"serialize_thunk", Thunk(msg, "serialize")}, - {"deserialize_thunk", Thunk(msg, "deserialize")}, + {"Msg", msg.name()}, + {"QualifiedMsg", cpp::QualifiedClassName(&msg)}, + {"new_thunk", Thunk(ctx, msg, "new")}, + {"delete_thunk", Thunk(ctx, msg, "delete")}, + {"serialize_thunk", Thunk(ctx, msg, "serialize")}, + {"deserialize_thunk", Thunk(ctx, msg, "deserialize")}, {"nested_msg_thunks", [&] { - for (int i = 0; i < msg.desc().nested_type_count(); ++i) { - Context nested_msg = - msg.WithDesc(msg.desc().nested_type(i)); - GenerateThunksCc(nested_msg); + for (int i = 0; i < msg.nested_type_count(); ++i) { + GenerateThunksCc(ctx, *msg.nested_type(i)); } }}, {"accessor_thunks", [&] { - for (int i = 0; i < msg.desc().field_count(); ++i) { - GenerateAccessorThunkCc(msg.WithDesc(*msg.desc().field(i))); + for (int i = 0; i < msg.field_count(); ++i) { + GenerateAccessorThunkCc(ctx, *msg.field(i)); } }}, {"oneof_thunks", [&] { - for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) { - GenerateOneofThunkCc(msg.WithDesc(*msg.desc().real_oneof_decl(i))); + for (int i = 0; i < msg.real_oneof_decl_count(); ++i) { + GenerateOneofThunkCc(ctx, *msg.real_oneof_decl(i)); } }}}, R"cc( diff --git a/src/google/protobuf/compiler/rust/message.h b/src/google/protobuf/compiler/rust/message.h index 4d1c392f70..e2734eb88c 100644 --- a/src/google/protobuf/compiler/rust/message.h +++ b/src/google/protobuf/compiler/rust/message.h @@ -21,10 +21,10 @@ namespace compiler { namespace rust { // Generates code for a particular message in `.pb.rs`. -void GenerateRs(Context msg); +void GenerateRs(Context& ctx, const Descriptor& msg); // Generates code for a particular message in `.pb.thunk.cc`. -void GenerateThunksCc(Context msg); +void GenerateThunksCc(Context& ctx, const Descriptor& msg); } // namespace rust } // namespace compiler diff --git a/src/google/protobuf/compiler/rust/naming.cc b/src/google/protobuf/compiler/rust/naming.cc index 6685493c04..12fb0891a4 100644 --- a/src/google/protobuf/compiler/rust/naming.cc +++ b/src/google/protobuf/compiler/rust/naming.cc @@ -26,22 +26,23 @@ namespace protobuf { namespace compiler { namespace rust { namespace { -std::string GetUnderscoreDelimitedFullName(Context msg) { - std::string result = msg.desc().full_name(); +std::string GetUnderscoreDelimitedFullName(Context& ctx, + const Descriptor& msg) { + std::string result = msg.full_name(); absl::StrReplaceAll({{".", "_"}}, &result); return result; } } // namespace -std::string GetCrateName(Context dep) { - absl::string_view path = dep.desc().name(); +std::string GetCrateName(Context& ctx, const FileDescriptor& dep) { + absl::string_view path = dep.name(); auto basename = path.substr(path.rfind('/') + 1); return absl::StrReplaceAll(basename, {{".", "_"}, {"-", "_"}}); } -std::string GetRsFile(Context file) { - auto basename = StripProto(file.desc().name()); - switch (auto k = file.opts().kernel) { +std::string GetRsFile(Context& ctx, const FileDescriptor& file) { + auto basename = StripProto(file.name()); + switch (auto k = ctx.opts().kernel) { case Kernel::kUpb: return absl::StrCat(basename, ".u.pb.rs"); case Kernel::kCpp: @@ -52,42 +53,41 @@ std::string GetRsFile(Context file) { } } -std::string GetThunkCcFile(Context file) { - auto basename = StripProto(file.desc().name()); +std::string GetThunkCcFile(Context& ctx, const FileDescriptor& file) { + auto basename = StripProto(file.name()); return absl::StrCat(basename, ".pb.thunks.cc"); } -std::string GetHeaderFile(Context file) { - auto basename = StripProto(file.desc().name()); +std::string GetHeaderFile(Context& ctx, const FileDescriptor& file) { + auto basename = StripProto(file.name()); return absl::StrCat(basename, ".proto.h"); } namespace { template -std::string FieldPrefix(Context field) { - // NOTE: When field.is_upb(), this functions outputs must match the symbols +std::string FieldPrefix(Context& ctx, const T& field) { + // NOTE: When ctx.is_upb(), this functions outputs must match the symbols // that the upbc plugin generates exactly. Failure to do so correctly results // in a link-time failure. - absl::string_view prefix = field.is_cpp() ? "__rust_proto_thunk__" : ""; - std::string thunk_prefix = - absl::StrCat(prefix, GetUnderscoreDelimitedFullName( - field.WithDesc(field.desc().containing_type()))); + absl::string_view prefix = ctx.is_cpp() ? "__rust_proto_thunk__" : ""; + std::string thunk_prefix = absl::StrCat( + prefix, GetUnderscoreDelimitedFullName(ctx, *field.containing_type())); return thunk_prefix; } template -std::string Thunk(Context field, absl::string_view op) { - std::string thunk = FieldPrefix(field); +std::string Thunk(Context& ctx, const T& field, absl::string_view op) { + std::string thunk = FieldPrefix(ctx, field); absl::string_view format; - if (field.is_upb() && op == "get") { + if (ctx.is_upb() && op == "get") { // upb getter is simply the field name (no "get" in the name). format = "_$1"; - } else if (field.is_upb() && op == "get_mut") { + } else if (ctx.is_upb() && op == "get_mut") { // same as above, with with `mutable` prefix format = "_mutable_$1"; - } else if (field.is_upb() && op == "case") { + } else if (ctx.is_upb() && op == "case") { // some upb functions are in the order x_op compared to has/set/clear which // are in the other order e.g. op_x. format = "_$1_$0"; @@ -95,51 +95,53 @@ std::string Thunk(Context field, absl::string_view op) { format = "_$0_$1"; } - absl::SubstituteAndAppend(&thunk, format, op, field.desc().name()); + absl::SubstituteAndAppend(&thunk, format, op, field.name()); return thunk; } -std::string ThunkMapOrRepeated(Context field, +std::string ThunkMapOrRepeated(Context& ctx, const FieldDescriptor& field, absl::string_view op) { - if (!field.is_upb()) { - return Thunk(field, op); + if (!ctx.is_upb()) { + return Thunk(ctx, field, op); } - std::string thunk = absl::StrCat("_", FieldPrefix(field)); + std::string thunk = absl::StrCat("_", FieldPrefix(ctx, field)); absl::string_view format; if (op == "get") { - format = field.desc().is_map() ? "_$1_upb_map" : "_$1_upb_array"; + format = field.is_map() ? "_$1_upb_map" : "_$1_upb_array"; } else if (op == "get_mut") { - format = - field.desc().is_map() ? "_$1_mutable_upb_map" : "_$1_mutable_upb_array"; + format = field.is_map() ? "_$1_mutable_upb_map" : "_$1_mutable_upb_array"; } else { - return Thunk(field, op); + return Thunk(ctx, field, op); } - absl::SubstituteAndAppend(&thunk, format, op, field.desc().name()); + absl::SubstituteAndAppend(&thunk, format, op, field.name()); return thunk; } } // namespace -std::string Thunk(Context field, absl::string_view op) { - if (field.desc().is_map() || field.desc().is_repeated()) { - return ThunkMapOrRepeated(field, op); +std::string Thunk(Context& ctx, const FieldDescriptor& field, + absl::string_view op) { + if (field.is_map() || field.is_repeated()) { + return ThunkMapOrRepeated(ctx, field, op); } - return Thunk(field, op); + return Thunk(ctx, field, op); } -std::string Thunk(Context field, absl::string_view op) { - return Thunk(field, op); +std::string Thunk(Context& ctx, const OneofDescriptor& field, + absl::string_view op) { + return Thunk(ctx, field, op); } -std::string Thunk(Context msg, absl::string_view op) { - absl::string_view prefix = msg.is_cpp() ? "__rust_proto_thunk__" : ""; - return absl::StrCat(prefix, GetUnderscoreDelimitedFullName(msg), "_", op); +std::string Thunk(Context& ctx, const Descriptor& msg, absl::string_view op) { + absl::string_view prefix = ctx.is_cpp() ? "__rust_proto_thunk__" : ""; + return absl::StrCat(prefix, GetUnderscoreDelimitedFullName(ctx, msg), "_", + op); } -std::string PrimitiveRsTypeName(const FieldDescriptor& desc) { - switch (desc.type()) { +std::string PrimitiveRsTypeName(const FieldDescriptor& field) { + switch (field.type()) { case FieldDescriptor::TYPE_BOOL: return "bool"; case FieldDescriptor::TYPE_INT32: @@ -167,7 +169,7 @@ std::string PrimitiveRsTypeName(const FieldDescriptor& desc) { default: break; } - ABSL_LOG(FATAL) << "Unsupported field type: " << desc.type_name(); + ABSL_LOG(FATAL) << "Unsupported field type: " << field.type_name(); return ""; } @@ -180,20 +182,18 @@ std::string PrimitiveRsTypeName(const FieldDescriptor& desc) { // // If the message has no package and no containing messages then this returns // empty string. -std::string RustModule(Context msg) { - const Descriptor& desc = msg.desc(); - +std::string RustModule(Context& ctx, const Descriptor& msg) { std::vector modules; std::vector package_modules = - absl::StrSplit(desc.file()->package(), '.', absl::SkipEmpty()); + absl::StrSplit(msg.file()->package(), '.', absl::SkipEmpty()); modules.insert(modules.begin(), package_modules.begin(), package_modules.end()); // Innermost to outermost order. std::vector modules_from_containing_types; - const Descriptor* parent = desc.containing_type(); + const Descriptor* parent = msg.containing_type(); while (parent != nullptr) { modules_from_containing_types.push_back(absl::StrCat(parent->name(), "_")); parent = parent->containing_type(); @@ -213,27 +213,25 @@ std::string RustModule(Context msg) { return absl::StrJoin(modules, "::"); } -std::string RustInternalModuleName(Context file) { +std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file) { // TODO: Introduce a more robust mangling here to avoid conflicts // between `foo/bar/baz.proto` and `foo_bar/baz.proto`. - return absl::StrReplaceAll(StripProto(file.desc().name()), {{"/", "_"}}); + return absl::StrReplaceAll(StripProto(file.name()), {{"/", "_"}}); } -std::string GetCrateRelativeQualifiedPath(Context msg) { - return absl::StrCat(RustModule(msg), msg.desc().name()); +std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg) { + return absl::StrCat(RustModule(ctx, msg), msg.name()); } -std::string FieldInfoComment(Context field) { - absl::string_view label = - field.desc().is_repeated() ? "repeated" : "optional"; - std::string comment = - absl::StrCat(field.desc().name(), ": ", label, " ", - FieldDescriptor::TypeName(field.desc().type())); +std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field) { + absl::string_view label = field.is_repeated() ? "repeated" : "optional"; + std::string comment = absl::StrCat(field.name(), ": ", label, " ", + FieldDescriptor::TypeName(field.type())); - if (auto* m = field.desc().message_type()) { + if (auto* m = field.message_type()) { absl::StrAppend(&comment, " ", m->full_name()); } - if (auto* m = field.desc().enum_type()) { + if (auto* m = field.enum_type()) { absl::StrAppend(&comment, " ", m->full_name()); } diff --git a/src/google/protobuf/compiler/rust/naming.h b/src/google/protobuf/compiler/rust/naming.h index 2bb5cc539c..e6f7b7f970 100644 --- a/src/google/protobuf/compiler/rust/naming.h +++ b/src/google/protobuf/compiler/rust/naming.h @@ -19,25 +19,27 @@ namespace google { namespace protobuf { namespace compiler { namespace rust { -std::string GetCrateName(Context dep); +std::string GetCrateName(Context& ctx, const FileDescriptor& dep); -std::string GetRsFile(Context file); -std::string GetThunkCcFile(Context file); -std::string GetHeaderFile(Context file); +std::string GetRsFile(Context& ctx, const FileDescriptor& file); +std::string GetThunkCcFile(Context& ctx, const FileDescriptor& file); +std::string GetHeaderFile(Context& ctx, const FileDescriptor& file); -std::string Thunk(Context field, absl::string_view op); -std::string Thunk(Context field, absl::string_view op); +std::string Thunk(Context& ctx, const FieldDescriptor& field, + absl::string_view op); +std::string Thunk(Context& ctx, const OneofDescriptor& field, + absl::string_view op); -std::string Thunk(Context msg, absl::string_view op); +std::string Thunk(Context& ctx, const Descriptor& msg, absl::string_view op); -std::string PrimitiveRsTypeName(const FieldDescriptor& desc); +std::string PrimitiveRsTypeName(const FieldDescriptor& field); -std::string FieldInfoComment(Context field); +std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field); -std::string RustModule(Context msg); -std::string RustInternalModuleName(Context file); +std::string RustModule(Context& ctx, const Descriptor& msg); +std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file); -std::string GetCrateRelativeQualifiedPath(Context msg); +std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg); } // namespace rust } // namespace compiler diff --git a/src/google/protobuf/compiler/rust/oneof.cc b/src/google/protobuf/compiler/rust/oneof.cc index d6166ae6d6..ce6e3af3c5 100644 --- a/src/google/protobuf/compiler/rust/oneof.cc +++ b/src/google/protobuf/compiler/rust/oneof.cc @@ -78,27 +78,25 @@ std::string ToCamelCase(absl::string_view name) { return cpp::UnderscoresToCamelCase(name, /* upper initial letter */ true); } -std::string oneofViewEnumRsName(const OneofDescriptor& desc) { - return ToCamelCase(desc.name()); +std::string oneofViewEnumRsName(const OneofDescriptor& oneof) { + return ToCamelCase(oneof.name()); } -std::string oneofMutEnumRsName(const OneofDescriptor& desc) { - return ToCamelCase(desc.name()) + "Mut"; +std::string oneofMutEnumRsName(const OneofDescriptor& oneof) { + return ToCamelCase(oneof.name()) + "Mut"; } -std::string oneofCaseEnumName(const OneofDescriptor& desc) { +std::string oneofCaseEnumName(const OneofDescriptor& oneof) { // Note: This is the name used for the cpp Case enum, we use it for both // the Rust Case enum as well as for the cpp case enum in the cpp thunk. - return ToCamelCase(desc.name()) + "Case"; + return ToCamelCase(oneof.name()) + "Case"; } -std::string RsTypeNameView(Context field) { - const auto& desc = field.desc(); - - if (desc.options().has_ctype()) { +std::string RsTypeNameView(Context& ctx, const FieldDescriptor& field) { + if (field.options().has_ctype()) { return ""; // TODO: b/308792377 - ctype fields not supported yet. } - switch (desc.type()) { + switch (field.type()) { case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_INT64: case FieldDescriptor::TYPE_FIXED32: @@ -112,7 +110,7 @@ std::string RsTypeNameView(Context field) { case FieldDescriptor::TYPE_FLOAT: case FieldDescriptor::TYPE_DOUBLE: case FieldDescriptor::TYPE_BOOL: - return PrimitiveRsTypeName(desc); + return PrimitiveRsTypeName(field); case FieldDescriptor::TYPE_BYTES: return "&'msg [u8]"; case FieldDescriptor::TYPE_STRING: @@ -120,23 +118,21 @@ std::string RsTypeNameView(Context field) { case FieldDescriptor::TYPE_MESSAGE: return absl::StrCat( "::__pb::View<'msg, crate::", - GetCrateRelativeQualifiedPath(field.WithDesc(desc.message_type())), - ">"); + GetCrateRelativeQualifiedPath(ctx, *field.message_type()), ">"); case FieldDescriptor::TYPE_ENUM: // TODO: b/300257770 - Support enums. case FieldDescriptor::TYPE_GROUP: // Not supported yet. return ""; } - ABSL_LOG(FATAL) << "Unexpected field type: " << desc.type_name(); + ABSL_LOG(FATAL) << "Unexpected field type: " << field.type_name(); return ""; } -std::string RsTypeNameMut(Context field) { - const auto& desc = field.desc(); - if (desc.options().has_ctype()) { +std::string RsTypeNameMut(Context& ctx, const FieldDescriptor& field) { + if (field.options().has_ctype()) { return ""; // TODO: b/308792377 - ctype fields not supported yet. } - switch (desc.type()) { + switch (field.type()) { case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_INT64: case FieldDescriptor::TYPE_FIXED32: @@ -152,56 +148,54 @@ std::string RsTypeNameMut(Context field) { case FieldDescriptor::TYPE_BOOL: case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_STRING: - return absl::StrCat("::__pb::Mut<'msg, ", PrimitiveRsTypeName(desc), ">"); + return absl::StrCat("::__pb::Mut<'msg, ", PrimitiveRsTypeName(field), + ">"); case FieldDescriptor::TYPE_MESSAGE: return absl::StrCat( "::__pb::Mut<'msg, crate::", - GetCrateRelativeQualifiedPath(field.WithDesc(desc.message_type())), - ">"); + GetCrateRelativeQualifiedPath(ctx, *field.message_type()), ">"); case FieldDescriptor::TYPE_ENUM: // TODO: b/300257770 - Support enums. case FieldDescriptor::TYPE_GROUP: // Not supported yet. return ""; } - ABSL_LOG(FATAL) << "Unexpected field type: " << desc.type_name(); + ABSL_LOG(FATAL) << "Unexpected field type: " << field.type_name(); return ""; } } // namespace -void GenerateOneofDefinition(Context oneof) { - const auto& desc = oneof.desc(); - - oneof.Emit( - {{"view_enum_name", oneofViewEnumRsName(desc)}, - {"mut_enum_name", oneofMutEnumRsName(desc)}, +void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) { + ctx.Emit( + {{"view_enum_name", oneofViewEnumRsName(oneof)}, + {"mut_enum_name", oneofMutEnumRsName(oneof)}, {"view_fields", [&] { - for (int i = 0; i < desc.field_count(); ++i) { - const auto& field = *desc.field(i); - std::string rs_type = RsTypeNameView(oneof.WithDesc(field)); + for (int i = 0; i < oneof.field_count(); ++i) { + auto& field = *oneof.field(i); + std::string rs_type = RsTypeNameView(ctx, field); if (rs_type.empty()) { continue; } - oneof.Emit({{"name", ToCamelCase(field.name())}, - {"type", rs_type}, - {"number", std::to_string(field.number())}}, - R"rs($name$($type$) = $number$, + ctx.Emit({{"name", ToCamelCase(field.name())}, + {"type", rs_type}, + {"number", std::to_string(field.number())}}, + R"rs($name$($type$) = $number$, )rs"); } }}, {"mut_fields", [&] { - for (int i = 0; i < desc.field_count(); ++i) { - const auto& field = *desc.field(i); - std::string rs_type = RsTypeNameMut(oneof.WithDesc(field)); + for (int i = 0; i < oneof.field_count(); ++i) { + auto& field = *oneof.field(i); + std::string rs_type = RsTypeNameMut(ctx, field); if (rs_type.empty()) { continue; } - oneof.Emit({{"name", ToCamelCase(field.name())}, - {"type", rs_type}, - {"number", std::to_string(field.number())}}, - R"rs($name$($type$) = $number$, + ctx.Emit({{"name", ToCamelCase(field.name())}, + {"type", rs_type}, + {"number", std::to_string(field.number())}}, + R"rs($name$($type$) = $number$, )rs"); } }}}, @@ -236,18 +230,18 @@ void GenerateOneofDefinition(Context oneof) { // Note: This enum is used as the Thunk return type for getting which case is // used: it exactly matches the generate case enum that both cpp and upb use. - oneof.Emit({{"case_enum_name", oneofCaseEnumName(desc)}, - {"cases", - [&] { - for (int i = 0; i < desc.field_count(); ++i) { - const auto& field = desc.field(i); - oneof.Emit({{"name", ToCamelCase(field->name())}, - {"number", std::to_string(field->number())}}, - R"rs($name$ = $number$, + ctx.Emit({{"case_enum_name", oneofCaseEnumName(oneof)}, + {"cases", + [&] { + for (int i = 0; i < oneof.field_count(); ++i) { + auto& field = *oneof.field(i); + ctx.Emit({{"name", ToCamelCase(field.name())}, + {"number", std::to_string(field.number())}}, + R"rs($name$ = $number$, )rs"); - } - }}}, - R"rs( + } + }}}, + R"rs( #[repr(C)] #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub(super) enum $case_enum_name$ { @@ -260,23 +254,21 @@ void GenerateOneofDefinition(Context oneof) { )rs"); } -void GenerateOneofAccessors(Context oneof) { - const auto& desc = oneof.desc(); - - oneof.Emit( - {{"oneof_name", desc.name()}, - {"view_enum_name", oneofViewEnumRsName(desc)}, - {"mut_enum_name", oneofMutEnumRsName(desc)}, - {"case_enum_name", oneofCaseEnumName(desc)}, +void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof) { + ctx.Emit( + {{"oneof_name", oneof.name()}, + {"view_enum_name", oneofViewEnumRsName(oneof)}, + {"mut_enum_name", oneofMutEnumRsName(oneof)}, + {"case_enum_name", oneofCaseEnumName(oneof)}, {"view_cases", [&] { - for (int i = 0; i < desc.field_count(); ++i) { - const auto& field = *desc.field(i); - std::string rs_type = RsTypeNameView(oneof.WithDesc(field)); + for (int i = 0; i < oneof.field_count(); ++i) { + auto& field = *oneof.field(i); + std::string rs_type = RsTypeNameView(ctx, field); if (rs_type.empty()) { continue; } - oneof.Emit( + ctx.Emit( { {"case", ToCamelCase(field.name())}, {"rs_getter", field.name()}, @@ -288,13 +280,13 @@ void GenerateOneofAccessors(Context oneof) { }}, {"mut_cases", [&] { - for (int i = 0; i < desc.field_count(); ++i) { - const auto& field = *desc.field(i); - std::string rs_type = RsTypeNameMut(oneof.WithDesc(field)); + for (int i = 0; i < oneof.field_count(); ++i) { + auto& field = *oneof.field(i); + std::string rs_type = RsTypeNameMut(ctx, field); if (rs_type.empty()) { continue; } - oneof.Emit( + ctx.Emit( {{"case", ToCamelCase(field.name())}, {"rs_mut_getter", field.name() + "_mut"}, {"type", rs_type}, @@ -321,7 +313,7 @@ void GenerateOneofAccessors(Context oneof) { $Msg$_::$mut_enum_name$::$case$(self.$rs_mut_getter$()$into_mut_transform$), )rs"); } }}, - {"case_thunk", Thunk(oneof, "case")}}, + {"case_thunk", Thunk(ctx, oneof, "case")}}, R"rs( pub fn r#$oneof_name$(&self) -> $Msg$_::$view_enum_name$ { match unsafe { $case_thunk$(self.inner.msg) } { @@ -340,26 +332,24 @@ void GenerateOneofAccessors(Context oneof) { )rs"); } -void GenerateOneofExternC(Context oneof) { - const auto& desc = oneof.desc(); - oneof.Emit( +void GenerateOneofExternC(Context& ctx, const OneofDescriptor& oneof) { + ctx.Emit( { - {"case_enum_rs_name", oneofCaseEnumName(desc)}, - {"case_thunk", Thunk(oneof, "case")}, + {"case_enum_rs_name", oneofCaseEnumName(oneof)}, + {"case_thunk", Thunk(ctx, oneof, "case")}, }, R"rs( fn $case_thunk$(raw_msg: $pbi$::RawMessage) -> $Msg$_::$case_enum_rs_name$; )rs"); } -void GenerateOneofThunkCc(Context oneof) { - const auto& desc = oneof.desc(); - oneof.Emit( +void GenerateOneofThunkCc(Context& ctx, const OneofDescriptor& oneof) { + ctx.Emit( { - {"oneof_name", desc.name()}, - {"case_enum_name", oneofCaseEnumName(desc)}, - {"case_thunk", Thunk(oneof, "case")}, - {"QualifiedMsg", cpp::QualifiedClassName(desc.containing_type())}, + {"oneof_name", oneof.name()}, + {"case_enum_name", oneofCaseEnumName(oneof)}, + {"case_thunk", Thunk(ctx, oneof, "case")}, + {"QualifiedMsg", cpp::QualifiedClassName(oneof.containing_type())}, }, R"cc( $QualifiedMsg$::$case_enum_name$ $case_thunk$($QualifiedMsg$* msg) { diff --git a/src/google/protobuf/compiler/rust/oneof.h b/src/google/protobuf/compiler/rust/oneof.h index 98e2bc99f6..7ad2143f83 100644 --- a/src/google/protobuf/compiler/rust/oneof.h +++ b/src/google/protobuf/compiler/rust/oneof.h @@ -16,10 +16,10 @@ namespace protobuf { namespace compiler { namespace rust { -void GenerateOneofDefinition(Context oneof); -void GenerateOneofAccessors(Context oneof); -void GenerateOneofExternC(Context oneof); -void GenerateOneofThunkCc(Context oneof); +void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof); +void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof); +void GenerateOneofExternC(Context& ctx, const OneofDescriptor& oneof); +void GenerateOneofThunkCc(Context& ctx, const OneofDescriptor& oneof); } // namespace rust } // namespace compiler