Decouple Context from the Descriptor

PiperOrigin-RevId: 592029759
pull/15107/head
Protobuf Team Bot 1 year ago committed by Copybara-Service
parent 52ee619733
commit 542ca772fa
  1. 63
      src/google/protobuf/compiler/rust/accessors/accessor_generator.h
  2. 35
      src/google/protobuf/compiler/rust/accessors/accessors.cc
  3. 6
      src/google/protobuf/compiler/rust/accessors/accessors.h
  4. 42
      src/google/protobuf/compiler/rust/accessors/helpers.cc
  5. 3
      src/google/protobuf/compiler/rust/accessors/helpers.h
  6. 87
      src/google/protobuf/compiler/rust/accessors/map.cc
  7. 124
      src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc
  8. 85
      src/google/protobuf/compiler/rust/accessors/singular_message.cc
  9. 112
      src/google/protobuf/compiler/rust/accessors/singular_scalar.cc
  10. 139
      src/google/protobuf/compiler/rust/accessors/singular_string.cc
  11. 7
      src/google/protobuf/compiler/rust/accessors/unsupported_field.cc
  12. 9
      src/google/protobuf/compiler/rust/context.cc
  13. 42
      src/google/protobuf/compiler/rust/context.h
  14. 165
      src/google/protobuf/compiler/rust/generator.cc
  15. 337
      src/google/protobuf/compiler/rust/message.cc
  16. 4
      src/google/protobuf/compiler/rust/message.h
  17. 120
      src/google/protobuf/compiler/rust/naming.cc
  18. 26
      src/google/protobuf/compiler/rust/naming.h
  19. 156
      src/google/protobuf/compiler/rust/oneof.cc
  20. 8
      src/google/protobuf/compiler/rust/oneof.h

@ -26,25 +26,26 @@ class AccessorGenerator {
AccessorGenerator() = default; AccessorGenerator() = default;
virtual ~AccessorGenerator() = default; virtual ~AccessorGenerator() = default;
AccessorGenerator(const AccessorGenerator &) = delete; AccessorGenerator(const AccessorGenerator&) = delete;
AccessorGenerator(AccessorGenerator &&) = delete; AccessorGenerator(AccessorGenerator&&) = delete;
AccessorGenerator &operator=(const AccessorGenerator &) = delete; AccessorGenerator& operator=(const AccessorGenerator&) = delete;
AccessorGenerator &operator=(AccessorGenerator &&) = delete; AccessorGenerator& operator=(AccessorGenerator&&) = delete;
// Constructs a generator for the given field. // Constructs a generator for the given field.
// //
// Returns `nullptr` if there is no known generator for this field. // Returns `nullptr` if there is no known generator for this field.
static std::unique_ptr<AccessorGenerator> For(Context<FieldDescriptor> field); static std::unique_ptr<AccessorGenerator> For(Context& ctx,
const FieldDescriptor& field);
void GenerateMsgImpl(Context<FieldDescriptor> field) const { void GenerateMsgImpl(Context& ctx, const FieldDescriptor& field) const {
InMsgImpl(field); InMsgImpl(ctx, field);
} }
void GenerateExternC(Context<FieldDescriptor> field) const { void GenerateExternC(Context& ctx, const FieldDescriptor& field) const {
InExternC(field); InExternC(ctx, field);
} }
void GenerateThunkCc(Context<FieldDescriptor> field) const { void GenerateThunkCc(Context& ctx, const FieldDescriptor& field) const {
ABSL_CHECK(field.is_cpp()); ABSL_CHECK(ctx.is_cpp());
InThunkCc(field); InThunkCc(ctx, field);
} }
private: private:
@ -54,53 +55,53 @@ class AccessorGenerator {
// prologue to inject variables automatically. // prologue to inject variables automatically.
// Called inside the main inherent `impl Msg {}` block. // Called inside the main inherent `impl Msg {}` block.
virtual void InMsgImpl(Context<FieldDescriptor> field) const {} virtual void InMsgImpl(Context& ctx, const FieldDescriptor& field) const {}
// Called inside of a message's `extern "C" {}` block. // Called inside of a message's `extern "C" {}` block.
virtual void InExternC(Context<FieldDescriptor> field) const {} virtual void InExternC(Context& ctx, const FieldDescriptor& field) const {}
// Called inside of an `extern "C" {}` block in the `.thunk.cc` file, if such // Called inside of an `extern "C" {}` block in the `.thunk.cc` file, if such
// a file is being generated. // a file is being generated.
virtual void InThunkCc(Context<FieldDescriptor> field) const {} virtual void InThunkCc(Context& ctx, const FieldDescriptor& field) const {}
}; };
class SingularScalar final : public AccessorGenerator { class SingularScalar final : public AccessorGenerator {
public: public:
~SingularScalar() override = default; ~SingularScalar() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override; void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
void InExternC(Context<FieldDescriptor> field) const override; void InExternC(Context& ctx, const FieldDescriptor& field) const override;
void InThunkCc(Context<FieldDescriptor> field) const override; void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
}; };
class SingularString final : public AccessorGenerator { class SingularString final : public AccessorGenerator {
public: public:
~SingularString() override = default; ~SingularString() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override; void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
void InExternC(Context<FieldDescriptor> field) const override; void InExternC(Context& ctx, const FieldDescriptor& field) const override;
void InThunkCc(Context<FieldDescriptor> field) const override; void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
}; };
class SingularMessage final : public AccessorGenerator { class SingularMessage final : public AccessorGenerator {
public: public:
~SingularMessage() override = default; ~SingularMessage() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override; void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
void InExternC(Context<FieldDescriptor> field) const override; void InExternC(Context& ctx, const FieldDescriptor& field) const override;
void InThunkCc(Context<FieldDescriptor> field) const override; void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
}; };
class RepeatedScalar final : public AccessorGenerator { class RepeatedScalar final : public AccessorGenerator {
public: public:
~RepeatedScalar() override = default; ~RepeatedScalar() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override; void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
void InExternC(Context<FieldDescriptor> field) const override; void InExternC(Context& ctx, const FieldDescriptor& field) const override;
void InThunkCc(Context<FieldDescriptor> field) const override; void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
}; };
class UnsupportedField final : public AccessorGenerator { class UnsupportedField final : public AccessorGenerator {
public: public:
explicit UnsupportedField(std::string reason) : reason_(std::move(reason)) {} explicit UnsupportedField(std::string reason) : reason_(std::move(reason)) {}
~UnsupportedField() override = default; ~UnsupportedField() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override; void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
private: private:
std::string reason_; std::string reason_;
@ -109,9 +110,9 @@ class UnsupportedField final : public AccessorGenerator {
class Map final : public AccessorGenerator { class Map final : public AccessorGenerator {
public: public:
~Map() override = default; ~Map() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override; void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
void InExternC(Context<FieldDescriptor> field) const override; void InExternC(Context& ctx, const FieldDescriptor& field) const override;
void InThunkCc(Context<FieldDescriptor> field) const override; void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
}; };
} // namespace rust } // namespace rust

@ -23,17 +23,16 @@ namespace rust {
namespace { namespace {
std::unique_ptr<AccessorGenerator> AccessorGeneratorFor( std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
Context<FieldDescriptor> field) { Context& ctx, const FieldDescriptor& field) {
const FieldDescriptor& desc = field.desc();
// TODO: We do not support [ctype=FOO] (used to set the field // TODO: We do not support [ctype=FOO] (used to set the field
// type in C++ to cord or string_piece) in V0.6 API. // type in C++ to cord or string_piece) in V0.6 API.
if (desc.options().has_ctype()) { if (field.options().has_ctype()) {
return std::make_unique<UnsupportedField>( return std::make_unique<UnsupportedField>(
"fields with ctype not supported"); "fields with ctype not supported");
} }
if (desc.is_map()) { if (field.is_map()) {
auto value_type = desc.message_type()->map_value()->type(); auto value_type = field.message_type()->map_value()->type();
switch (value_type) { switch (value_type) {
case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_BYTES:
case FieldDescriptor::TYPE_ENUM: case FieldDescriptor::TYPE_ENUM:
@ -46,7 +45,7 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
} }
} }
switch (desc.type()) { switch (field.type()) {
case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_INT64: case FieldDescriptor::TYPE_INT64:
case FieldDescriptor::TYPE_FIXED32: case FieldDescriptor::TYPE_FIXED32:
@ -60,22 +59,22 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
case FieldDescriptor::TYPE_FLOAT: case FieldDescriptor::TYPE_FLOAT:
case FieldDescriptor::TYPE_DOUBLE: case FieldDescriptor::TYPE_DOUBLE:
case FieldDescriptor::TYPE_BOOL: case FieldDescriptor::TYPE_BOOL:
if (desc.is_repeated()) { if (field.is_repeated()) {
return std::make_unique<RepeatedScalar>(); return std::make_unique<RepeatedScalar>();
} }
return std::make_unique<SingularScalar>(); return std::make_unique<SingularScalar>();
case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_BYTES:
case FieldDescriptor::TYPE_STRING: case FieldDescriptor::TYPE_STRING:
if (desc.is_repeated()) { if (field.is_repeated()) {
return std::make_unique<UnsupportedField>("repeated str not supported"); return std::make_unique<UnsupportedField>("repeated str not supported");
} }
return std::make_unique<SingularString>(); return std::make_unique<SingularString>();
case FieldDescriptor::TYPE_MESSAGE: case FieldDescriptor::TYPE_MESSAGE:
if (desc.is_repeated()) { if (field.is_repeated()) {
return std::make_unique<UnsupportedField>("repeated msg not supported"); return std::make_unique<UnsupportedField>("repeated msg not supported");
} }
if (!field.generator_context().is_file_in_current_crate( if (!ctx.generator_context().is_file_in_current_crate(
desc.message_type()->file())) { *field.message_type()->file())) {
return std::make_unique<UnsupportedField>( return std::make_unique<UnsupportedField>(
"message fields that are imported from another proto_library" "message fields that are imported from another proto_library"
" (defined in a separate Rust crate) are not supported"); " (defined in a separate Rust crate) are not supported");
@ -89,21 +88,21 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
return std::make_unique<UnsupportedField>("group not supported"); return std::make_unique<UnsupportedField>("group not supported");
} }
ABSL_LOG(FATAL) << "Unexpected field type: " << desc.type(); ABSL_LOG(FATAL) << "Unexpected field type: " << field.type();
} }
} // namespace } // namespace
void GenerateAccessorMsgImpl(Context<FieldDescriptor> field) { void GenerateAccessorMsgImpl(Context& ctx, const FieldDescriptor& field) {
AccessorGeneratorFor(field)->GenerateMsgImpl(field); AccessorGeneratorFor(ctx, field)->GenerateMsgImpl(ctx, field);
} }
void GenerateAccessorExternC(Context<FieldDescriptor> field) { void GenerateAccessorExternC(Context& ctx, const FieldDescriptor& field) {
AccessorGeneratorFor(field)->GenerateExternC(field); AccessorGeneratorFor(ctx, field)->GenerateExternC(ctx, field);
} }
void GenerateAccessorThunkCc(Context<FieldDescriptor> field) { void GenerateAccessorThunkCc(Context& ctx, const FieldDescriptor& field) {
AccessorGeneratorFor(field)->GenerateThunkCc(field); AccessorGeneratorFor(ctx, field)->GenerateThunkCc(ctx, field);
} }
} // namespace rust } // namespace rust

@ -16,9 +16,9 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
void GenerateAccessorMsgImpl(Context<FieldDescriptor> field); void GenerateAccessorMsgImpl(Context& ctx, const FieldDescriptor& field);
void GenerateAccessorExternC(Context<FieldDescriptor> field); void GenerateAccessorExternC(Context& ctx, const FieldDescriptor& field);
void GenerateAccessorThunkCc(Context<FieldDescriptor> field); void GenerateAccessorThunkCc(Context& ctx, const FieldDescriptor& field);
} // namespace rust } // namespace rust
} // namespace compiler } // namespace compiler

@ -15,7 +15,6 @@
#include "absl/strings/escaping.h" #include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "google/protobuf/compiler/rust/context.h"
#include "google/protobuf/descriptor.h" #include "google/protobuf/descriptor.h"
#include "google/protobuf/io/strtod.h" #include "google/protobuf/io/strtod.h"
@ -24,33 +23,32 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
std::string DefaultValue(Context<FieldDescriptor> field) { std::string DefaultValue(const FieldDescriptor& field) {
switch (field.desc().type()) { switch (field.type()) {
case FieldDescriptor::TYPE_DOUBLE: case FieldDescriptor::TYPE_DOUBLE:
if (std::isfinite(field.desc().default_value_double())) { if (std::isfinite(field.default_value_double())) {
return absl::StrCat(io::SimpleDtoa(field.desc().default_value_double()), return absl::StrCat(io::SimpleDtoa(field.default_value_double()),
"f64"); "f64");
} else if (std::isnan(field.desc().default_value_double())) { } else if (std::isnan(field.default_value_double())) {
return std::string("f64::NAN"); return std::string("f64::NAN");
} else if (field.desc().default_value_double() == } else if (field.default_value_double() ==
std::numeric_limits<double>::infinity()) { std::numeric_limits<double>::infinity()) {
return std::string("f64::INFINITY"); return std::string("f64::INFINITY");
} else if (field.desc().default_value_double() == } else if (field.default_value_double() ==
-std::numeric_limits<double>::infinity()) { -std::numeric_limits<double>::infinity()) {
return std::string("f64::NEG_INFINITY"); return std::string("f64::NEG_INFINITY");
} else { } else {
ABSL_LOG(FATAL) << "unreachable"; ABSL_LOG(FATAL) << "unreachable";
} }
case FieldDescriptor::TYPE_FLOAT: case FieldDescriptor::TYPE_FLOAT:
if (std::isfinite(field.desc().default_value_float())) { if (std::isfinite(field.default_value_float())) {
return absl::StrCat(io::SimpleFtoa(field.desc().default_value_float()), return absl::StrCat(io::SimpleFtoa(field.default_value_float()), "f32");
"f32"); } else if (std::isnan(field.default_value_float())) {
} else if (std::isnan(field.desc().default_value_float())) {
return std::string("f32::NAN"); return std::string("f32::NAN");
} else if (field.desc().default_value_float() == } else if (field.default_value_float() ==
std::numeric_limits<float>::infinity()) { std::numeric_limits<float>::infinity()) {
return std::string("f32::INFINITY"); return std::string("f32::INFINITY");
} else if (field.desc().default_value_float() == } else if (field.default_value_float() ==
-std::numeric_limits<float>::infinity()) { -std::numeric_limits<float>::infinity()) {
return std::string("f32::NEG_INFINITY"); return std::string("f32::NEG_INFINITY");
} else { } else {
@ -59,27 +57,27 @@ std::string DefaultValue(Context<FieldDescriptor> field) {
case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_SFIXED32: case FieldDescriptor::TYPE_SFIXED32:
case FieldDescriptor::TYPE_SINT32: case FieldDescriptor::TYPE_SINT32:
return absl::StrFormat("%d", field.desc().default_value_int32()); return absl::StrFormat("%d", field.default_value_int32());
case FieldDescriptor::TYPE_INT64: case FieldDescriptor::TYPE_INT64:
case FieldDescriptor::TYPE_SFIXED64: case FieldDescriptor::TYPE_SFIXED64:
case FieldDescriptor::TYPE_SINT64: case FieldDescriptor::TYPE_SINT64:
return absl::StrFormat("%d", field.desc().default_value_int64()); return absl::StrFormat("%d", field.default_value_int64());
case FieldDescriptor::TYPE_FIXED64: case FieldDescriptor::TYPE_FIXED64:
case FieldDescriptor::TYPE_UINT64: case FieldDescriptor::TYPE_UINT64:
return absl::StrFormat("%u", field.desc().default_value_uint64()); return absl::StrFormat("%u", field.default_value_uint64());
case FieldDescriptor::TYPE_FIXED32: case FieldDescriptor::TYPE_FIXED32:
case FieldDescriptor::TYPE_UINT32: case FieldDescriptor::TYPE_UINT32:
return absl::StrFormat("%u", field.desc().default_value_uint32()); return absl::StrFormat("%u", field.default_value_uint32());
case FieldDescriptor::TYPE_BOOL: case FieldDescriptor::TYPE_BOOL:
return absl::StrFormat("%v", field.desc().default_value_bool()); return absl::StrFormat("%v", field.default_value_bool());
case FieldDescriptor::TYPE_STRING: case FieldDescriptor::TYPE_STRING:
case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_BYTES:
return absl::StrFormat( return absl::StrFormat("b\"%s\"",
"b\"%s\"", absl::CHexEscape(field.desc().default_value_string())); absl::CHexEscape(field.default_value_string()));
case FieldDescriptor::TYPE_GROUP: case FieldDescriptor::TYPE_GROUP:
case FieldDescriptor::TYPE_MESSAGE: case FieldDescriptor::TYPE_MESSAGE:
case FieldDescriptor::TYPE_ENUM: case FieldDescriptor::TYPE_ENUM:
ABSL_LOG(FATAL) << "Unsupported field type: " << field.desc().type_name(); ABSL_LOG(FATAL) << "Unsupported field type: " << field.type_name();
} }
ABSL_LOG(FATAL) << "unreachable"; ABSL_LOG(FATAL) << "unreachable";
} }

@ -10,7 +10,6 @@
#include <string> #include <string>
#include "google/protobuf/compiler/rust/context.h"
#include "google/protobuf/descriptor.h" #include "google/protobuf/descriptor.h"
namespace google { namespace google {
@ -23,7 +22,7 @@ namespace rust {
// Both strings and bytes are represented as a byte string literal, i.e. in the // Both strings and bytes are represented as a byte string literal, i.e. in the
// format `b"default value here"`. It is the caller's responsibility to convert // format `b"default value here"`. It is the caller's responsibility to convert
// the byte literal to an actual string, if needed. // the byte literal to an actual string, if needed.
std::string DefaultValue(Context<FieldDescriptor> field); std::string DefaultValue(const FieldDescriptor& field);
} // namespace rust } // namespace rust
} // namespace compiler } // namespace compiler

@ -17,19 +17,19 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
void Map::InMsgImpl(Context<FieldDescriptor> field) const { void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field) const {
auto& key_type = *field.desc().message_type()->map_key(); auto& key_type = *field.message_type()->map_key();
auto& value_type = *field.desc().message_type()->map_value(); auto& value_type = *field.message_type()->map_value();
field.Emit({{"field", field.desc().name()}, ctx.Emit({{"field", field.name()},
{"Key", PrimitiveRsTypeName(key_type)}, {"Key", PrimitiveRsTypeName(key_type)},
{"Value", PrimitiveRsTypeName(value_type)}, {"Value", PrimitiveRsTypeName(value_type)},
{"getter_thunk", Thunk(field, "get")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")}, {"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"getter", {"getter",
[&] { [&] {
if (field.is_upb()) { if (ctx.is_upb()) {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$(&self) pub fn r#$field$(&self)
-> $pb$::View<'_, $pb$::Map<$Key$, $Value$>> { -> $pb$::View<'_, $pb$::Map<$Key$, $Value$>> {
let inner = unsafe { let inner = unsafe {
@ -44,8 +44,8 @@ void Map::InMsgImpl(Context<FieldDescriptor> field) const {
}); });
$pb$::MapView::from_inner($pbi$::Private, inner) $pb$::MapView::from_inner($pbi$::Private, inner)
})rs"); })rs");
} else { } else {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$(&self) pub fn r#$field$(&self)
-> $pb$::View<'_, $pb$::Map<$Key$, $Value$>> { -> $pb$::View<'_, $pb$::Map<$Key$, $Value$>> {
let inner = $pbr$::MapInner { let inner = $pbr$::MapInner {
@ -55,12 +55,12 @@ void Map::InMsgImpl(Context<FieldDescriptor> field) const {
}; };
$pb$::MapView::from_inner($pbi$::Private, inner) $pb$::MapView::from_inner($pbi$::Private, inner)
})rs"); })rs");
} }
}}, }},
{"getter_mut", {"getter_mut",
[&] { [&] {
if (field.is_upb()) { if (ctx.is_upb()) {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self) pub fn r#$field$_mut(&mut self)
-> $pb$::Mut<'_, $pb$::Map<$Key$, $Value$>> { -> $pb$::Mut<'_, $pb$::Map<$Key$, $Value$>> {
let raw = unsafe { let raw = unsafe {
@ -75,8 +75,8 @@ void Map::InMsgImpl(Context<FieldDescriptor> field) const {
}; };
$pb$::MapMut::from_inner($pbi$::Private, inner) $pb$::MapMut::from_inner($pbi$::Private, inner)
})rs"); })rs");
} else { } else {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self) pub fn r#$field$_mut(&mut self)
-> $pb$::Mut<'_, $pb$::Map<$Key$, $Value$>> { -> $pb$::Mut<'_, $pb$::Map<$Key$, $Value$>> {
let inner = $pbr$::MapInner { let inner = $pbr$::MapInner {
@ -86,30 +86,30 @@ void Map::InMsgImpl(Context<FieldDescriptor> field) const {
}; };
$pb$::MapMut::from_inner($pbi$::Private, inner) $pb$::MapMut::from_inner($pbi$::Private, inner)
})rs"); })rs");
} }
}}}, }}},
R"rs( R"rs(
$getter$ $getter$
$getter_mut$ $getter_mut$
)rs"); )rs");
} }
void Map::InExternC(Context<FieldDescriptor> field) const { void Map::InExternC(Context& ctx, const FieldDescriptor& field) const {
field.Emit( ctx.Emit(
{ {
{"getter_thunk", Thunk(field, "get")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")}, {"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"getter", {"getter",
[&] { [&] {
if (field.is_upb()) { if (ctx.is_upb()) {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
fn $getter_thunk$(raw_msg: $pbi$::RawMessage) fn $getter_thunk$(raw_msg: $pbi$::RawMessage)
-> Option<$pbi$::RawMap>; -> Option<$pbi$::RawMap>;
fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage, fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage,
arena: $pbi$::RawArena) -> $pbi$::RawMap; arena: $pbi$::RawArena) -> $pbi$::RawMap;
)rs"); )rs");
} else { } else {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
fn $getter_thunk$(msg: $pbi$::RawMessage) -> $pbi$::RawMap; fn $getter_thunk$(msg: $pbi$::RawMessage) -> $pbi$::RawMap;
fn $getter_mut_thunk$(msg: $pbi$::RawMessage,) -> $pbi$::RawMap; fn $getter_mut_thunk$(msg: $pbi$::RawMessage,) -> $pbi$::RawMap;
)rs"); )rs");
@ -121,20 +121,19 @@ void Map::InExternC(Context<FieldDescriptor> field) const {
)rs"); )rs");
} }
void Map::InThunkCc(Context<FieldDescriptor> field) const { void Map::InThunkCc(Context& ctx, const FieldDescriptor& field) const {
field.Emit( ctx.Emit(
{{"field", cpp::FieldName(&field.desc())}, {{"field", cpp::FieldName(&field)},
{"Key", cpp::PrimitiveTypeName( {"Key",
field.desc().message_type()->map_key()->cpp_type())}, cpp::PrimitiveTypeName(field.message_type()->map_key()->cpp_type())},
{"Value", cpp::PrimitiveTypeName( {"Value",
field.desc().message_type()->map_value()->cpp_type())}, cpp::PrimitiveTypeName(field.message_type()->map_value()->cpp_type())},
{"QualifiedMsg", {"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
cpp::QualifiedClassName(field.desc().containing_type())}, {"getter_thunk", Thunk(ctx, field, "get")},
{"getter_thunk", Thunk(field, "get")}, {"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"getter_mut_thunk", Thunk(field, "get_mut")},
{"impls", {"impls",
[&] { [&] {
field.Emit( ctx.Emit(
R"cc( R"cc(
const void* $getter_thunk$($QualifiedMsg$& msg) { const void* $getter_thunk$($QualifiedMsg$& msg) {
return &msg.$field$(); return &msg.$field$();

@ -17,15 +17,16 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const { void RepeatedScalar::InMsgImpl(Context& ctx,
field.Emit({{"field", field.desc().name()}, const FieldDescriptor& field) const {
{"Scalar", PrimitiveRsTypeName(field.desc())}, ctx.Emit({{"field", field.name()},
{"getter_thunk", Thunk(field, "get")}, {"Scalar", PrimitiveRsTypeName(field)},
{"getter_mut_thunk", Thunk(field, "get_mut")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"getter", {"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
[&] { {"getter",
if (field.is_upb()) { [&] {
field.Emit({}, R"rs( if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> { pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> {
unsafe { unsafe {
$getter_thunk$( $getter_thunk$(
@ -40,8 +41,8 @@ void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const {
) )
} }
)rs"); )rs");
} else { } else {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> { pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> {
unsafe { unsafe {
$pb$::RepeatedView::from_raw( $pb$::RepeatedView::from_raw(
@ -51,13 +52,13 @@ void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const {
} }
} }
)rs"); )rs");
} }
}}, }},
{"clearer_thunk", Thunk(field, "clear")}, {"clearer_thunk", Thunk(ctx, field, "clear")},
{"field_mutator_getter", {"field_mutator_getter",
[&] { [&] {
if (field.is_upb()) { if (ctx.is_upb()) {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> { pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> {
unsafe { unsafe {
$pb$::RepeatedMut::from_inner( $pb$::RepeatedMut::from_inner(
@ -75,8 +76,8 @@ void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const {
} }
} }
)rs"); )rs");
} else { } else {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> { pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> {
unsafe { unsafe {
$pb$::RepeatedMut::from_inner( $pb$::RepeatedMut::from_inner(
@ -89,22 +90,23 @@ void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const {
} }
} }
)rs"); )rs");
} }
}}}, }}},
R"rs( R"rs(
$getter$ $getter$
$field_mutator_getter$ $field_mutator_getter$
)rs"); )rs");
} }
void RepeatedScalar::InExternC(Context<FieldDescriptor> field) const { void RepeatedScalar::InExternC(Context& ctx,
field.Emit({{"Scalar", PrimitiveRsTypeName(field.desc())}, const FieldDescriptor& field) const {
{"getter_thunk", Thunk(field, "get")}, ctx.Emit({{"Scalar", PrimitiveRsTypeName(field)},
{"getter_mut_thunk", Thunk(field, "get_mut")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"getter", {"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
[&] { {"getter",
if (field.is_upb()) { [&] {
field.Emit(R"rs( if (ctx.is_upb()) {
ctx.Emit(R"rs(
fn $getter_mut_thunk$( fn $getter_mut_thunk$(
raw_msg: $pbi$::RawMessage, raw_msg: $pbi$::RawMessage,
size: *const usize, size: *const usize,
@ -116,44 +118,44 @@ void RepeatedScalar::InExternC(Context<FieldDescriptor> field) const {
size: *const usize, size: *const usize,
) -> Option<$pbi$::RawRepeatedField>; ) -> Option<$pbi$::RawRepeatedField>;
)rs"); )rs");
} else { } else {
field.Emit(R"rs( ctx.Emit(R"rs(
fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField; fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField;
fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField; fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField;
)rs"); )rs");
} }
}}, }},
{"clearer_thunk", Thunk(field, "clear")}}, {"clearer_thunk", Thunk(ctx, field, "clear")}},
R"rs( R"rs(
fn $clearer_thunk$(raw_msg: $pbi$::RawMessage); fn $clearer_thunk$(raw_msg: $pbi$::RawMessage);
$getter$ $getter$
)rs"); )rs");
} }
void RepeatedScalar::InThunkCc(Context<FieldDescriptor> field) const { void RepeatedScalar::InThunkCc(Context& ctx,
field.Emit({{"field", cpp::FieldName(&field.desc())}, const FieldDescriptor& field) const {
{"Scalar", cpp::PrimitiveTypeName(field.desc().cpp_type())}, ctx.Emit({{"field", cpp::FieldName(&field)},
{"QualifiedMsg", {"Scalar", cpp::PrimitiveTypeName(field.cpp_type())},
cpp::QualifiedClassName(field.desc().containing_type())}, {"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"clearer_thunk", Thunk(field, "clear")}, {"clearer_thunk", Thunk(ctx, field, "clear")},
{"getter_thunk", Thunk(field, "get")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")}, {"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"impls", {"impls",
[&] { [&] {
field.Emit( ctx.Emit(
R"cc( R"cc(
void $clearer_thunk$($QualifiedMsg$* msg) { void $clearer_thunk$($QualifiedMsg$* msg) {
msg->clear_$field$(); msg->clear_$field$();
} }
google::protobuf::RepeatedField<$Scalar$>* $getter_mut_thunk$($QualifiedMsg$* msg) { google::protobuf::RepeatedField<$Scalar$>* $getter_mut_thunk$($QualifiedMsg$* msg) {
return msg->mutable_$field$(); return msg->mutable_$field$();
} }
const google::protobuf::RepeatedField<$Scalar$>& $getter_thunk$($QualifiedMsg$& msg) { const google::protobuf::RepeatedField<$Scalar$>& $getter_thunk$($QualifiedMsg$& msg) {
return msg.$field$(); return msg.$field$();
} }
)cc"); )cc");
}}}, }}},
"$impls$"); "$impls$");
} }
} // namespace rust } // namespace rust

@ -17,23 +17,23 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
void SingularMessage::InMsgImpl(Context<FieldDescriptor> field) const { void SingularMessage::InMsgImpl(Context& ctx,
Context<Descriptor> d = field.WithDesc(field.desc().message_type()); const FieldDescriptor& field) const {
auto& msg = *field.message_type();
auto prefix = "crate::" + GetCrateRelativeQualifiedPath(ctx, msg);
auto prefix = "crate::" + GetCrateRelativeQualifiedPath(d); ctx.Emit(
field.Emit(
{ {
{"prefix", prefix}, {"prefix", prefix},
{"field", field.desc().name()}, {"field", field.name()},
{"getter_thunk", Thunk(field, "get")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")}, {"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"clearer_thunk", Thunk(field, "clear")}, {"clearer_thunk", Thunk(ctx, field, "clear")},
{ {
"view_body", "view_body",
[&] { [&] {
if (field.is_upb()) { if (ctx.is_upb()) {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
let submsg = unsafe { $getter_thunk$(self.inner.msg) }; let submsg = unsafe { $getter_thunk$(self.inner.msg) };
// For upb, getters return null if the field is unset, so we need // For upb, getters return null if the field is unset, so we need
// to check for null and return the default instance manually. // to check for null and return the default instance manually.
@ -46,7 +46,7 @@ void SingularMessage::InMsgImpl(Context<FieldDescriptor> field) const {
} }
)rs"); )rs");
} else { } else {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
// For C++ kernel, getters automatically return the // For C++ kernel, getters automatically return the
// default_instance if the field is unset. // default_instance if the field is unset.
let submsg = unsafe { $getter_thunk$(self.inner.msg) }; let submsg = unsafe { $getter_thunk$(self.inner.msg) };
@ -57,15 +57,15 @@ void SingularMessage::InMsgImpl(Context<FieldDescriptor> field) const {
}, },
{"submessage_mut", {"submessage_mut",
[&] { [&] {
if (field.is_upb()) { if (ctx.is_upb()) {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
let submsg = unsafe { let submsg = unsafe {
$getter_mut_thunk$(self.inner.msg, self.inner.arena.raw()) $getter_mut_thunk$(self.inner.msg, self.inner.arena.raw())
}; };
$prefix$Mut::new($pbi$::Private, &mut self.inner, submsg) $prefix$Mut::new($pbi$::Private, &mut self.inner, submsg)
)rs"); )rs");
} else { } else {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
let submsg = unsafe { $getter_mut_thunk$(self.inner.msg) }; let submsg = unsafe { $getter_mut_thunk$(self.inner.msg) };
$prefix$Mut::new($pbi$::Private, &mut self.inner, submsg) $prefix$Mut::new($pbi$::Private, &mut self.inner, submsg)
)rs"); )rs");
@ -87,21 +87,22 @@ void SingularMessage::InMsgImpl(Context<FieldDescriptor> field) const {
)rs"); )rs");
} }
void SingularMessage::InExternC(Context<FieldDescriptor> field) const { void SingularMessage::InExternC(Context& ctx,
field.Emit( const FieldDescriptor& field) const {
ctx.Emit(
{ {
{"getter_thunk", Thunk(field, "get")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")}, {"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"clearer_thunk", Thunk(field, "clear")}, {"clearer_thunk", Thunk(ctx, field, "clear")},
{"getter_mut", {"getter_mut",
[&] { [&] {
if (field.is_cpp()) { if (ctx.is_cpp()) {
field.Emit( ctx.Emit(
R"rs( R"rs(
fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage) fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage)
-> $pbi$::RawMessage;)rs"); -> $pbi$::RawMessage;)rs");
} else { } else {
field.Emit( ctx.Emit(
R"rs(fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage, R"rs(fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage,
arena: $pbi$::RawArena) arena: $pbi$::RawArena)
-> $pbi$::RawMessage;)rs"); -> $pbi$::RawMessage;)rs");
@ -109,13 +110,13 @@ void SingularMessage::InExternC(Context<FieldDescriptor> field) const {
}}, }},
{"ReturnType", {"ReturnType",
[&] { [&] {
if (field.is_cpp()) { if (ctx.is_cpp()) {
// guaranteed to have a nonnull submsg for the cpp kernel // guaranteed to have a nonnull submsg for the cpp kernel
field.Emit({}, "$pbi$::RawMessage;"); ctx.Emit({}, "$pbi$::RawMessage;");
} else { } else {
// upb kernel may return NULL for a submsg, we can detect this // upb kernel may return NULL for a submsg, we can detect this
// in terra rust if the option returned is None // in terra rust if the option returned is None
field.Emit({}, "Option<$pbi$::RawMessage>;"); ctx.Emit({}, "Option<$pbi$::RawMessage>;");
} }
}}, }},
}, },
@ -126,22 +127,22 @@ void SingularMessage::InExternC(Context<FieldDescriptor> field) const {
)rs"); )rs");
} }
void SingularMessage::InThunkCc(Context<FieldDescriptor> field) const { void SingularMessage::InThunkCc(Context& ctx,
field.Emit({{"QualifiedMsg", const FieldDescriptor& field) const {
cpp::QualifiedClassName(field.desc().containing_type())}, ctx.Emit({{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"getter_thunk", Thunk(field, "get")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")}, {"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"clearer_thunk", Thunk(field, "clear")}, {"clearer_thunk", Thunk(ctx, field, "clear")},
{"field", cpp::FieldName(&field.desc())}}, {"field", cpp::FieldName(&field)}},
R"cc( R"cc(
const void* $getter_thunk$($QualifiedMsg$* msg) { const void* $getter_thunk$($QualifiedMsg$* msg) {
return static_cast<const void*>(&msg->$field$()); return static_cast<const void*>(&msg->$field$());
} }
void* $getter_mut_thunk$($QualifiedMsg$* msg) { void* $getter_mut_thunk$($QualifiedMsg$* msg) {
return static_cast<void*>(msg->mutable_$field$()); return static_cast<void*>(msg->mutable_$field$());
} }
void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); }
)cc"); )cc");
} }
} // namespace rust } // namespace rust

@ -18,16 +18,17 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
void SingularScalar::InMsgImpl(Context<FieldDescriptor> field) const { void SingularScalar::InMsgImpl(Context& ctx,
field.Emit( const FieldDescriptor& field) const {
ctx.Emit(
{ {
{"field", field.desc().name()}, {"field", field.name()},
{"Scalar", PrimitiveRsTypeName(field.desc())}, {"Scalar", PrimitiveRsTypeName(field)},
{"hazzer_thunk", Thunk(field, "has")}, {"hazzer_thunk", Thunk(ctx, field, "has")},
{"default_value", DefaultValue(field)}, {"default_value", DefaultValue(field)},
{"getter", {"getter",
[&] { [&] {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$(&self) -> $Scalar$ { pub fn r#$field$(&self) -> $Scalar$ {
unsafe { $getter_thunk$(self.inner.msg) } unsafe { $getter_thunk$(self.inner.msg) }
} }
@ -35,9 +36,9 @@ void SingularScalar::InMsgImpl(Context<FieldDescriptor> field) const {
}}, }},
{"getter_opt", {"getter_opt",
[&] { [&] {
if (!field.desc().is_optional()) return; if (!field.is_optional()) return;
if (!field.desc().has_presence()) return; if (!field.has_presence()) return;
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$_opt(&self) -> $pb$::Optional<$Scalar$> { pub fn r#$field$_opt(&self) -> $pb$::Optional<$Scalar$> {
if !unsafe { $hazzer_thunk$(self.inner.msg) } { if !unsafe { $hazzer_thunk$(self.inner.msg) } {
return $pb$::Optional::Unset($default_value$); return $pb$::Optional::Unset($default_value$);
@ -47,13 +48,13 @@ void SingularScalar::InMsgImpl(Context<FieldDescriptor> field) const {
} }
)rs"); )rs");
}}, }},
{"getter_thunk", Thunk(field, "get")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"setter_thunk", Thunk(field, "set")}, {"setter_thunk", Thunk(ctx, field, "set")},
{"clearer_thunk", Thunk(field, "clear")}, {"clearer_thunk", Thunk(ctx, field, "clear")},
{"field_mutator_getter", {"field_mutator_getter",
[&] { [&] {
if (field.desc().has_presence()) { if (field.has_presence()) {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self) -> $pb$::FieldEntry<'_, $Scalar$> { pub fn r#$field$_mut(&mut self) -> $pb$::FieldEntry<'_, $Scalar$> {
static VTABLE: $pbi$::PrimitiveOptionalMutVTable<$Scalar$> = static VTABLE: $pbi$::PrimitiveOptionalMutVTable<$Scalar$> =
$pbi$::PrimitiveOptionalMutVTable::new( $pbi$::PrimitiveOptionalMutVTable::new(
@ -76,7 +77,7 @@ void SingularScalar::InMsgImpl(Context<FieldDescriptor> field) const {
} }
)rs"); )rs");
} else { } else {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self) -> $pb$::Mut<'_, $Scalar$> { pub fn r#$field$_mut(&mut self) -> $pb$::Mut<'_, $Scalar$> {
static VTABLE: $pbi$::PrimitiveVTable<$Scalar$> = static VTABLE: $pbi$::PrimitiveVTable<$Scalar$> =
$pbi$::PrimitiveVTable::new( $pbi$::PrimitiveVTable::new(
@ -114,56 +115,57 @@ void SingularScalar::InMsgImpl(Context<FieldDescriptor> field) const {
)rs"); )rs");
} }
void SingularScalar::InExternC(Context<FieldDescriptor> field) const { void SingularScalar::InExternC(Context& ctx,
field.Emit({{"Scalar", PrimitiveRsTypeName(field.desc())}, const FieldDescriptor& field) const {
{"hazzer_thunk", Thunk(field, "has")}, ctx.Emit({{"Scalar", PrimitiveRsTypeName(field)},
{"getter_thunk", Thunk(field, "get")}, {"hazzer_thunk", Thunk(ctx, field, "has")},
{"setter_thunk", Thunk(field, "set")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"clearer_thunk", Thunk(field, "clear")}, {"setter_thunk", Thunk(ctx, field, "set")},
{"hazzer_and_clearer", {"clearer_thunk", Thunk(ctx, field, "clear")},
[&] { {"hazzer_and_clearer",
if (field.desc().has_presence()) { [&] {
field.Emit( if (field.has_presence()) {
R"rs( ctx.Emit(
R"rs(
fn $hazzer_thunk$(raw_msg: $pbi$::RawMessage) -> bool; fn $hazzer_thunk$(raw_msg: $pbi$::RawMessage) -> bool;
fn $clearer_thunk$(raw_msg: $pbi$::RawMessage); fn $clearer_thunk$(raw_msg: $pbi$::RawMessage);
)rs"); )rs");
} }
}}}, }}},
R"rs( R"rs(
$hazzer_and_clearer$ $hazzer_and_clearer$
fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $Scalar$; fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $Scalar$;
fn $setter_thunk$(raw_msg: $pbi$::RawMessage, val: $Scalar$); fn $setter_thunk$(raw_msg: $pbi$::RawMessage, val: $Scalar$);
)rs"); )rs");
} }
void SingularScalar::InThunkCc(Context<FieldDescriptor> field) const { void SingularScalar::InThunkCc(Context& ctx,
field.Emit({{"field", cpp::FieldName(&field.desc())}, const FieldDescriptor& field) const {
{"Scalar", cpp::PrimitiveTypeName(field.desc().cpp_type())}, ctx.Emit({{"field", cpp::FieldName(&field)},
{"QualifiedMsg", {"Scalar", cpp::PrimitiveTypeName(field.cpp_type())},
cpp::QualifiedClassName(field.desc().containing_type())}, {"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"hazzer_thunk", Thunk(field, "has")}, {"hazzer_thunk", Thunk(ctx, field, "has")},
{"getter_thunk", Thunk(field, "get")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"setter_thunk", Thunk(field, "set")}, {"setter_thunk", Thunk(ctx, field, "set")},
{"clearer_thunk", Thunk(field, "clear")}, {"clearer_thunk", Thunk(ctx, field, "clear")},
{"hazzer_and_clearer", {"hazzer_and_clearer",
[&] { [&] {
if (field.desc().has_presence()) { if (field.has_presence()) {
field.Emit(R"cc( ctx.Emit(R"cc(
bool $hazzer_thunk$($QualifiedMsg$* msg) { bool $hazzer_thunk$($QualifiedMsg$* msg) {
return msg->has_$field$(); return msg->has_$field$();
} }
void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); }
)cc"); )cc");
}
}}},
R"cc(
$hazzer_and_clearer$;
$Scalar$ $getter_thunk$($QualifiedMsg$* msg) { return msg->$field$(); }
void $setter_thunk$($QualifiedMsg$* msg, $Scalar$ val) {
msg->set_$field$(val);
} }
)cc"); }}},
R"cc(
$hazzer_and_clearer$;
$Scalar$ $getter_thunk$($QualifiedMsg$* msg) { return msg->$field$(); }
void $setter_thunk$($QualifiedMsg$* msg, $Scalar$ val) {
msg->set_$field$(val);
}
)cc");
} }
} // namespace rust } // namespace rust

@ -20,24 +20,25 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
void SingularString::InMsgImpl(Context<FieldDescriptor> field) const { void SingularString::InMsgImpl(Context& ctx,
std::string hazzer_thunk = Thunk(field, "has"); const FieldDescriptor& field) const {
std::string getter_thunk = Thunk(field, "get"); std::string hazzer_thunk = Thunk(ctx, field, "has");
std::string setter_thunk = Thunk(field, "set"); std::string getter_thunk = Thunk(ctx, field, "get");
std::string proxied_type = PrimitiveRsTypeName(field.desc()); std::string setter_thunk = Thunk(ctx, field, "set");
std::string proxied_type = PrimitiveRsTypeName(field);
auto transform_view = [&] { auto transform_view = [&] {
if (field.desc().type() == FieldDescriptor::TYPE_STRING) { if (field.type() == FieldDescriptor::TYPE_STRING) {
field.Emit(R"rs( ctx.Emit(R"rs(
// SAFETY: The runtime doesn't require ProtoStr to be UTF-8. // SAFETY: The runtime doesn't require ProtoStr to be UTF-8.
unsafe { $pb$::ProtoStr::from_utf8_unchecked(view) } unsafe { $pb$::ProtoStr::from_utf8_unchecked(view) }
)rs"); )rs");
} else { } else {
field.Emit("view"); ctx.Emit("view");
} }
}; };
field.Emit( ctx.Emit(
{ {
{"field", field.desc().name()}, {"field", field.name()},
{"hazzer_thunk", hazzer_thunk}, {"hazzer_thunk", hazzer_thunk},
{"getter_thunk", getter_thunk}, {"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk}, {"setter_thunk", setter_thunk},
@ -45,12 +46,12 @@ void SingularString::InMsgImpl(Context<FieldDescriptor> field) const {
{"transform_view", transform_view}, {"transform_view", transform_view},
{"field_optional_getter", {"field_optional_getter",
[&] { [&] {
if (!field.desc().is_optional()) return; if (!field.is_optional()) return;
if (!field.desc().has_presence()) return; if (!field.has_presence()) return;
field.Emit({{"hazzer_thunk", hazzer_thunk}, ctx.Emit({{"hazzer_thunk", hazzer_thunk},
{"getter_thunk", getter_thunk}, {"getter_thunk", getter_thunk},
{"transform_view", transform_view}}, {"transform_view", transform_view}},
R"rs( R"rs(
pub fn $field$_opt(&self) -> $pb$::Optional<&$proxied_type$> { pub fn $field$_opt(&self) -> $pb$::Optional<&$proxied_type$> {
let view = unsafe { $getter_thunk$(self.inner.msg).as_ref() }; let view = unsafe { $getter_thunk$(self.inner.msg).as_ref() };
$pb$::Optional::new( $pb$::Optional::new(
@ -62,30 +63,29 @@ void SingularString::InMsgImpl(Context<FieldDescriptor> field) const {
}}, }},
{"field_mutator_getter", {"field_mutator_getter",
[&] { [&] {
if (field.desc().has_presence()) { if (field.has_presence()) {
field.Emit( ctx.Emit(
{ {
{"field", field.desc().name()}, {"field", field.name()},
{"proxied_type", proxied_type}, {"proxied_type", proxied_type},
{"default_val", DefaultValue(field)}, {"default_val", DefaultValue(field)},
{"view_type", proxied_type}, {"view_type", proxied_type},
{"transform_field_entry", {"transform_field_entry",
[&] { [&] {
if (field.desc().type() == if (field.type() == FieldDescriptor::TYPE_STRING) {
FieldDescriptor::TYPE_STRING) { ctx.Emit(R"rs(
field.Emit(R"rs(
$pb$::ProtoStrMut::field_entry_from_bytes( $pb$::ProtoStrMut::field_entry_from_bytes(
$pbi$::Private, out $pbi$::Private, out
) )
)rs"); )rs");
} else { } else {
field.Emit("out"); ctx.Emit("out");
} }
}}, }},
{"hazzer_thunk", hazzer_thunk}, {"hazzer_thunk", hazzer_thunk},
{"getter_thunk", getter_thunk}, {"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk}, {"setter_thunk", setter_thunk},
{"clearer_thunk", Thunk(field, "clear")}, {"clearer_thunk", Thunk(ctx, field, "clear")},
}, },
R"rs( R"rs(
pub fn $field$_mut(&mut self) -> $pb$::FieldEntry<'_, $proxied_type$> { pub fn $field$_mut(&mut self) -> $pb$::FieldEntry<'_, $proxied_type$> {
@ -112,11 +112,11 @@ void SingularString::InMsgImpl(Context<FieldDescriptor> field) const {
} }
)rs"); )rs");
} else { } else {
field.Emit({{"field", field.desc().name()}, ctx.Emit({{"field", field.name()},
{"proxied_type", proxied_type}, {"proxied_type", proxied_type},
{"getter_thunk", getter_thunk}, {"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk}}, {"setter_thunk", setter_thunk}},
R"rs( R"rs(
pub fn $field$_mut(&mut self) -> $pb$::Mut<'_, $proxied_type$> { pub fn $field$_mut(&mut self) -> $pb$::Mut<'_, $proxied_type$> {
static VTABLE: $pbi$::BytesMutVTable = unsafe { static VTABLE: $pbi$::BytesMutVTable = unsafe {
$pbi$::BytesMutVTable::new( $pbi$::BytesMutVTable::new(
@ -152,20 +152,21 @@ void SingularString::InMsgImpl(Context<FieldDescriptor> field) const {
)rs"); )rs");
} }
void SingularString::InExternC(Context<FieldDescriptor> field) const { void SingularString::InExternC(Context& ctx,
field.Emit({{"hazzer_thunk", Thunk(field, "has")}, const FieldDescriptor& field) const {
{"getter_thunk", Thunk(field, "get")}, ctx.Emit({{"hazzer_thunk", Thunk(ctx, field, "has")},
{"setter_thunk", Thunk(field, "set")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"clearer_thunk", Thunk(field, "clear")}, {"setter_thunk", Thunk(ctx, field, "set")},
{"hazzer", {"clearer_thunk", Thunk(ctx, field, "clear")},
[&] { {"hazzer",
if (field.desc().has_presence()) { [&] {
field.Emit(R"rs( if (field.has_presence()) {
ctx.Emit(R"rs(
fn $hazzer_thunk$(raw_msg: $pbi$::RawMessage) -> bool; fn $hazzer_thunk$(raw_msg: $pbi$::RawMessage) -> bool;
)rs"); )rs");
} }
}}}, }}},
R"rs( R"rs(
$hazzer$ $hazzer$
fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::PtrAndLen; fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::PtrAndLen;
fn $setter_thunk$(raw_msg: $pbi$::RawMessage, val: $pbi$::PtrAndLen); fn $setter_thunk$(raw_msg: $pbi$::RawMessage, val: $pbi$::PtrAndLen);
@ -173,35 +174,35 @@ void SingularString::InExternC(Context<FieldDescriptor> field) const {
)rs"); )rs");
} }
void SingularString::InThunkCc(Context<FieldDescriptor> field) const { void SingularString::InThunkCc(Context& ctx,
field.Emit({{"field", cpp::FieldName(&field.desc())}, const FieldDescriptor& field) const {
{"QualifiedMsg", ctx.Emit({{"field", cpp::FieldName(&field)},
cpp::QualifiedClassName(field.desc().containing_type())}, {"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"hazzer_thunk", Thunk(field, "has")}, {"hazzer_thunk", Thunk(ctx, field, "has")},
{"getter_thunk", Thunk(field, "get")}, {"getter_thunk", Thunk(ctx, field, "get")},
{"setter_thunk", Thunk(field, "set")}, {"setter_thunk", Thunk(ctx, field, "set")},
{"clearer_thunk", Thunk(field, "clear")}, {"clearer_thunk", Thunk(ctx, field, "clear")},
{"hazzer", {"hazzer",
[&] { [&] {
if (field.desc().has_presence()) { if (field.has_presence()) {
field.Emit(R"cc( ctx.Emit(R"cc(
bool $hazzer_thunk$($QualifiedMsg$* msg) { bool $hazzer_thunk$($QualifiedMsg$* msg) {
return msg->has_$field$(); return msg->has_$field$();
} }
void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); }
)cc"); )cc");
}
}}},
R"cc(
$hazzer$;
::google::protobuf::rust_internal::PtrAndLen $getter_thunk$($QualifiedMsg$* msg) {
absl::string_view val = msg->$field$();
return ::google::protobuf::rust_internal::PtrAndLen(val.data(), val.size());
}
void $setter_thunk$($QualifiedMsg$* msg, ::google::protobuf::rust_internal::PtrAndLen s) {
msg->set_$field$(absl::string_view(s.ptr, s.len));
} }
)cc"); }}},
R"cc(
$hazzer$;
::google::protobuf::rust_internal::PtrAndLen $getter_thunk$($QualifiedMsg$* msg) {
absl::string_view val = msg->$field$();
return ::google::protobuf::rust_internal::PtrAndLen(val.data(), val.size());
}
void $setter_thunk$($QualifiedMsg$* msg, ::google::protobuf::rust_internal::PtrAndLen s) {
msg->set_$field$(absl::string_view(s.ptr, s.len));
}
)cc");
} }
} // namespace rust } // namespace rust

@ -15,11 +15,12 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
void UnsupportedField::InMsgImpl(Context<FieldDescriptor> field) const { void UnsupportedField::InMsgImpl(Context& ctx,
field.Emit({{"reason", reason_}}, R"rs( const FieldDescriptor& field) const {
ctx.Emit({{"reason", reason_}}, R"rs(
// Unsupported! :( Reason: $reason$ // Unsupported! :( Reason: $reason$
)rs"); )rs");
field.printer().PrintRaw("\n"); ctx.printer().PrintRaw("\n");
} }
} // namespace rust } // namespace rust

@ -68,13 +68,12 @@ absl::StatusOr<Options> Options::Parse(absl::string_view param) {
return opts; return opts;
} }
bool IsInCurrentlyGeneratingCrate(Context<FileDescriptor> file) { bool IsInCurrentlyGeneratingCrate(Context& ctx, const FileDescriptor& file) {
return file.generator_context().is_file_in_current_crate(&file.desc()); return ctx.generator_context().is_file_in_current_crate(file);
} }
bool IsInCurrentlyGeneratingCrate(Context<Descriptor> message) { bool IsInCurrentlyGeneratingCrate(Context& ctx, const Descriptor& message) {
return message.generator_context().is_file_in_current_crate( return IsInCurrentlyGeneratingCrate(ctx, *message.file());
message.desc().file());
} }
} // namespace rust } // namespace rust

@ -53,14 +53,14 @@ class RustGeneratorContext {
const std::vector<const FileDescriptor*>* files_in_current_crate) const std::vector<const FileDescriptor*>* files_in_current_crate)
: files_in_current_crate_(*files_in_current_crate) {} : files_in_current_crate_(*files_in_current_crate) {}
const FileDescriptor* primary_file() const { const FileDescriptor& primary_file() const {
return files_in_current_crate_.front(); return *files_in_current_crate_.front();
} }
bool is_file_in_current_crate(const FileDescriptor* f) const { bool is_file_in_current_crate(const FileDescriptor& f) const {
return std::find(files_in_current_crate_.begin(), return std::find(files_in_current_crate_.begin(),
files_in_current_crate_.end(), files_in_current_crate_.end(),
f) != files_in_current_crate_.end(); &f) != files_in_current_crate_.end();
} }
private: private:
@ -68,26 +68,20 @@ class RustGeneratorContext {
}; };
// A context for generating a particular kind of definition. // A context for generating a particular kind of definition.
// This type acts as an options struct (as in go/totw/173) for most of the
// generator.
//
// `Descriptor` is the type of a descriptor.h class relevant for the current
// context.
template <typename Descriptor>
class Context { class Context {
public: public:
Context(const Options* opts, const Descriptor* desc, Context(const Options* opts,
const RustGeneratorContext* rust_generator_context, const RustGeneratorContext* rust_generator_context,
io::Printer* printer) io::Printer* printer)
: opts_(opts), : opts_(opts),
desc_(desc),
rust_generator_context_(rust_generator_context), rust_generator_context_(rust_generator_context),
printer_(printer) {} printer_(printer) {}
Context(const Context&) = default; Context(const Context&) = delete;
Context& operator=(const Context&) = default; Context& operator=(const Context&) = delete;
Context(Context&&) = default;
Context& operator=(Context&&) = default;
const Descriptor& desc() const { return *desc_; }
const Options& opts() const { return *opts_; } const Options& opts() const { return *opts_; }
const RustGeneratorContext& generator_context() const { const RustGeneratorContext& generator_context() const {
return *rust_generator_context_; return *rust_generator_context_;
@ -99,19 +93,8 @@ class Context {
// NOTE: prefer ctx.Emit() over ctx.printer().Emit(); // NOTE: prefer ctx.Emit() over ctx.printer().Emit();
io::Printer& printer() const { return *printer_; } io::Printer& printer() const { return *printer_; }
// Creates a new context over a different descriptor.
template <typename D>
Context<D> WithDesc(const D& desc) const {
return Context<D>(opts_, &desc, rust_generator_context_, printer_);
}
template <typename D>
Context<D> WithDesc(const D* desc) const {
return Context<D>(opts_, desc, rust_generator_context_, printer_);
}
Context WithPrinter(io::Printer* printer) const { Context WithPrinter(io::Printer* printer) const {
return Context(opts_, desc_, rust_generator_context_, printer); return Context(opts_, rust_generator_context_, printer);
} }
// Forwards to Emit(), which will likely be called all the time. // Forwards to Emit(), which will likely be called all the time.
@ -128,13 +111,12 @@ class Context {
private: private:
const Options* opts_; const Options* opts_;
const Descriptor* desc_;
const RustGeneratorContext* rust_generator_context_; const RustGeneratorContext* rust_generator_context_;
io::Printer* printer_; io::Printer* printer_;
}; };
bool IsInCurrentlyGeneratingCrate(Context<FileDescriptor> file); bool IsInCurrentlyGeneratingCrate(Context& ctx, const FileDescriptor& file);
bool IsInCurrentlyGeneratingCrate(Context<Descriptor> message); bool IsInCurrentlyGeneratingCrate(Context& ctx, const Descriptor& message);
} // namespace rust } // namespace rust
} // namespace compiler } // namespace compiler

@ -48,12 +48,11 @@ namespace {
// pub mod submodule { // pub mod submodule {
// pub mod separator { // pub mod separator {
// ``` // ```
void EmitOpeningOfPackageModules(absl::string_view pkg, void EmitOpeningOfPackageModules(Context& ctx, absl::string_view pkg) {
Context<FileDescriptor> file) {
if (pkg.empty()) return; if (pkg.empty()) return;
for (absl::string_view segment : absl::StrSplit(pkg, '.')) { for (absl::string_view segment : absl::StrSplit(pkg, '.')) {
file.Emit({{"segment", segment}}, ctx.Emit({{"segment", segment}},
R"rs( R"rs(
pub mod $segment$ { pub mod $segment$ {
)rs"); )rs");
} }
@ -70,14 +69,13 @@ void EmitOpeningOfPackageModules(absl::string_view pkg,
// } // mod uses // } // mod uses
// } // mod package // } // mod package
// ``` // ```
void EmitClosingOfPackageModules(absl::string_view pkg, void EmitClosingOfPackageModules(Context& ctx, absl::string_view pkg) {
Context<FileDescriptor> file) {
if (pkg.empty()) return; if (pkg.empty()) return;
std::vector<absl::string_view> segments = absl::StrSplit(pkg, '.'); std::vector<absl::string_view> segments = absl::StrSplit(pkg, '.');
absl::c_reverse(segments); absl::c_reverse(segments);
for (absl::string_view segment : segments) { for (absl::string_view segment : segments) {
file.Emit({{"segment", segment}}, R"rs( ctx.Emit({{"segment", segment}}, R"rs(
} // mod $segment$ } // mod $segment$
)rs"); )rs");
} }
@ -87,14 +85,13 @@ void EmitClosingOfPackageModules(absl::string_view pkg,
// `non_primary_src` into the `primary_file`. // `non_primary_src` into the `primary_file`.
// //
// `non_primary_src` has to be a non-primary src of the current `proto_library`. // `non_primary_src` has to be a non-primary src of the current `proto_library`.
void EmitPubUseOfOwnMessages(Context<FileDescriptor>& primary_file, void EmitPubUseOfOwnMessages(Context& ctx, const FileDescriptor& primary_file,
const Context<FileDescriptor>& non_primary_src) { const FileDescriptor& non_primary_src) {
for (int i = 0; i < non_primary_src.desc().message_type_count(); ++i) { for (int i = 0; i < non_primary_src.message_type_count(); ++i) {
auto msg = primary_file.WithDesc(non_primary_src.desc().message_type(i)); auto& msg = *non_primary_src.message_type(i);
auto mod = RustInternalModuleName(non_primary_src); auto mod = RustInternalModuleName(ctx, non_primary_src);
auto name = msg.desc().name(); ctx.Emit({{"mod", mod}, {"Msg", msg.name()}},
primary_file.Emit({{"mod", mod}, {"Msg", name}}, R"rs(
R"rs(
pub use crate::$mod$::$Msg$; pub use crate::$mod$::$Msg$;
// TODO Address use for imported crates // TODO Address use for imported crates
pub use crate::$mod$::$Msg$View; pub use crate::$mod$::$Msg$View;
@ -109,14 +106,15 @@ void EmitPubUseOfOwnMessages(Context<FileDescriptor>& primary_file,
// //
// `dep` is a primary src of a dependency of the current `proto_library`. // `dep` is a primary src of a dependency of the current `proto_library`.
// TODO: Add support for public import of non-primary srcs of deps. // TODO: Add support for public import of non-primary srcs of deps.
void EmitPubUseForImportedMessages(Context<FileDescriptor>& primary_file, void EmitPubUseForImportedMessages(Context& ctx,
const Context<FileDescriptor>& dep) { const FileDescriptor& primary_file,
std::string crate_name = GetCrateName(dep); const FileDescriptor& dep) {
for (int i = 0; i < dep.desc().message_type_count(); ++i) { std::string crate_name = GetCrateName(ctx, dep);
auto msg = primary_file.WithDesc(dep.desc().message_type(i)); for (int i = 0; i < dep.message_type_count(); ++i) {
auto path = GetCrateRelativeQualifiedPath(msg); auto& msg = *dep.message_type(i);
primary_file.Emit({{"crate", crate_name}, {"pkg::Msg", path}}, auto path = GetCrateRelativeQualifiedPath(ctx, msg);
R"rs( ctx.Emit({{"crate", crate_name}, {"pkg::Msg", path}},
R"rs(
pub use $crate$::$pkg::Msg$; pub use $crate$::$pkg::Msg$;
pub use $crate$::$pkg::Msg$View; pub use $crate$::$pkg::Msg$View;
)rs"); )rs");
@ -124,9 +122,9 @@ void EmitPubUseForImportedMessages(Context<FileDescriptor>& primary_file,
} }
// Emits all public imports of the current file // Emits all public imports of the current file
void EmitPublicImports(Context<FileDescriptor>& primary_file) { void EmitPublicImports(Context& ctx, const FileDescriptor& primary_file) {
for (int i = 0; i < primary_file.desc().public_dependency_count(); ++i) { for (int i = 0; i < primary_file.public_dependency_count(); ++i) {
auto dep_file = primary_file.desc().public_dependency(i); auto& dep_file = *primary_file.public_dependency(i);
// If the publicly imported file is a src of the current `proto_library` // If the publicly imported file is a src of the current `proto_library`
// we don't need to emit `pub use` here, we already do it for all srcs in // we don't need to emit `pub use` here, we already do it for all srcs in
// RustGenerator::Generate. In other words, all srcs are implicitly publicly // RustGenerator::Generate. In other words, all srcs are implicitly publicly
@ -134,30 +132,29 @@ void EmitPublicImports(Context<FileDescriptor>& primary_file) {
// TODO: Handle the case where a non-primary src with the same // TODO: Handle the case where a non-primary src with the same
// declared package as the primary src publicly imports a file that the // declared package as the primary src publicly imports a file that the
// primary doesn't. // primary doesn't.
auto dep = primary_file.WithDesc(dep_file); if (IsInCurrentlyGeneratingCrate(ctx, dep_file)) {
if (IsInCurrentlyGeneratingCrate(dep)) {
return; return;
} }
EmitPubUseForImportedMessages(primary_file, dep); EmitPubUseForImportedMessages(ctx, primary_file, dep_file);
} }
} }
// Emits submodule declarations so `rustc` can find non primary sources from the // Emits submodule declarations so `rustc` can find non primary sources from the
// primary file. // primary file.
void DeclareSubmodulesForNonPrimarySrcs( void DeclareSubmodulesForNonPrimarySrcs(
Context<FileDescriptor>& primary_file, Context& ctx, const FileDescriptor& primary_file,
absl::Span<const Context<FileDescriptor>> non_primary_srcs) { absl::Span<const FileDescriptor* const> non_primary_srcs) {
std::string primary_file_path = GetRsFile(primary_file); std::string primary_file_path = GetRsFile(ctx, primary_file);
RelativePath primary_relpath(primary_file_path); RelativePath primary_relpath(primary_file_path);
for (const auto& non_primary_src : non_primary_srcs) { for (const FileDescriptor* non_primary_src : non_primary_srcs) {
std::string non_primary_file_path = GetRsFile(non_primary_src); std::string non_primary_file_path = GetRsFile(ctx, *non_primary_src);
std::string relative_mod_path = std::string relative_mod_path =
primary_relpath.Relative(RelativePath(non_primary_file_path)); primary_relpath.Relative(RelativePath(non_primary_file_path));
primary_file.Emit({{"file_path", relative_mod_path}, ctx.Emit({{"file_path", relative_mod_path},
{"foo", primary_file_path}, {"foo", primary_file_path},
{"bar", non_primary_file_path}, {"bar", non_primary_file_path},
{"mod_name", RustInternalModuleName(non_primary_src)}}, {"mod_name", RustInternalModuleName(ctx, *non_primary_src)}},
R"rs( R"rs(
#[path="$file_path$"] #[path="$file_path$"]
pub mod $mod_name$; pub mod $mod_name$;
)rs"); )rs");
@ -169,33 +166,32 @@ void DeclareSubmodulesForNonPrimarySrcs(
// //
// Returns the non-primary sources that should be reexported from the package of // Returns the non-primary sources that should be reexported from the package of
// the primary file. // the primary file.
std::vector<const Context<FileDescriptor>*> ReexportMessagesFromSubmodules( std::vector<const FileDescriptor*> ReexportMessagesFromSubmodules(
Context<FileDescriptor>& primary_file, Context& ctx, const FileDescriptor& primary_file,
absl::Span<const Context<FileDescriptor>> non_primary_srcs) { absl::Span<const FileDescriptor* const> non_primary_srcs) {
absl::btree_map<absl::string_view, absl::btree_map<absl::string_view, std::vector<const FileDescriptor*>>
std::vector<const Context<FileDescriptor>*>>
packages; packages;
for (const Context<FileDescriptor>& ctx : non_primary_srcs) { for (const FileDescriptor* file : non_primary_srcs) {
packages[ctx.desc().package()].push_back(&ctx); packages[file->package()].push_back(file);
} }
for (const auto& pair : packages) { for (const auto& pair : packages) {
// We will deal with messages for the package of the primary file later. // We will deal with messages for the package of the primary file later.
auto fds = pair.second; auto fds = pair.second;
absl::string_view package = fds[0]->desc().package(); absl::string_view package = fds[0]->package();
if (package == primary_file.desc().package()) continue; if (package == primary_file.package()) continue;
EmitOpeningOfPackageModules(package, primary_file); EmitOpeningOfPackageModules(ctx, package);
for (const Context<FileDescriptor>* c : fds) { for (const FileDescriptor* c : fds) {
EmitPubUseOfOwnMessages(primary_file, *c); EmitPubUseOfOwnMessages(ctx, primary_file, *c);
} }
EmitClosingOfPackageModules(package, primary_file); EmitClosingOfPackageModules(ctx, package);
} }
return packages[primary_file.desc().package()]; return packages[primary_file.package()];
} }
} // namespace } // namespace
bool RustGenerator::Generate(const FileDescriptor* file_desc, bool RustGenerator::Generate(const FileDescriptor* file,
const std::string& parameter, const std::string& parameter,
GeneratorContext* generator_context, GeneratorContext* generator_context,
std::string* error) const { std::string* error) const {
@ -210,15 +206,15 @@ bool RustGenerator::Generate(const FileDescriptor* file_desc,
RustGeneratorContext rust_generator_context(&files_in_current_crate); RustGeneratorContext rust_generator_context(&files_in_current_crate);
Context<FileDescriptor> file(&*opts, file_desc, &rust_generator_context, Context ctx_without_printer(&*opts, &rust_generator_context, nullptr);
nullptr);
auto outfile = absl::WrapUnique(generator_context->Open(GetRsFile(file))); auto outfile = absl::WrapUnique(
generator_context->Open(GetRsFile(ctx_without_printer, *file)));
io::Printer printer(outfile.get()); io::Printer printer(outfile.get());
file = file.WithPrinter(&printer); Context ctx = ctx_without_printer.WithPrinter(&printer);
// Convenience shorthands for common symbols. // Convenience shorthands for common symbols.
auto v = file.printer().WithVars({ auto v = ctx.printer().WithVars({
{"std", "::__std"}, {"std", "::__std"},
{"pb", "::__pb"}, {"pb", "::__pb"},
{"pbi", "::__pb::__internal"}, {"pbi", "::__pb::__internal"},
@ -227,67 +223,66 @@ bool RustGenerator::Generate(const FileDescriptor* file_desc,
{"Phantom", "::__std::marker::PhantomData"}, {"Phantom", "::__std::marker::PhantomData"},
}); });
file.Emit({{"kernel", KernelRsName(file.opts().kernel)}}, R"rs( ctx.Emit({{"kernel", KernelRsName(ctx.opts().kernel)}}, R"rs(
extern crate protobuf_$kernel$ as __pb; extern crate protobuf_$kernel$ as __pb;
extern crate std as __std; extern crate std as __std;
)rs"); )rs");
std::vector<Context<FileDescriptor>> file_contexts; std::vector<const FileDescriptor*> file_contexts;
for (const FileDescriptor* f : files_in_current_crate) { for (const FileDescriptor* f : files_in_current_crate) {
file_contexts.push_back(file.WithDesc(*f)); file_contexts.push_back(f);
} }
// Generating the primary file? // Generating the primary file?
if (file_desc == rust_generator_context.primary_file()) { if (file == &rust_generator_context.primary_file()) {
auto non_primary_srcs = absl::MakeConstSpan(file_contexts).subspan(1); auto non_primary_srcs = absl::MakeConstSpan(file_contexts).subspan(1);
DeclareSubmodulesForNonPrimarySrcs(file, non_primary_srcs); DeclareSubmodulesForNonPrimarySrcs(ctx, *file, non_primary_srcs);
std::vector<const Context<FileDescriptor>*> std::vector<const FileDescriptor*> non_primary_srcs_in_primary_package =
non_primary_srcs_in_primary_package = ReexportMessagesFromSubmodules(ctx, *file, non_primary_srcs);
ReexportMessagesFromSubmodules(file, non_primary_srcs);
EmitOpeningOfPackageModules(file.desc().package(), file); EmitOpeningOfPackageModules(ctx, file->package());
for (const Context<FileDescriptor>* non_primary_file : for (const FileDescriptor* non_primary_file :
non_primary_srcs_in_primary_package) { non_primary_srcs_in_primary_package) {
EmitPubUseOfOwnMessages(file, *non_primary_file); EmitPubUseOfOwnMessages(ctx, *file, *non_primary_file);
} }
} }
EmitPublicImports(file); EmitPublicImports(ctx, *file);
std::unique_ptr<io::ZeroCopyOutputStream> thunks_cc; std::unique_ptr<io::ZeroCopyOutputStream> thunks_cc;
std::unique_ptr<io::Printer> thunks_printer; std::unique_ptr<io::Printer> thunks_printer;
if (file.is_cpp()) { if (ctx.is_cpp()) {
thunks_cc.reset(generator_context->Open(GetThunkCcFile(file))); thunks_cc.reset(generator_context->Open(GetThunkCcFile(ctx, *file)));
thunks_printer = std::make_unique<io::Printer>(thunks_cc.get()); thunks_printer = std::make_unique<io::Printer>(thunks_cc.get());
thunks_printer->Emit({{"proto_h", GetHeaderFile(file)}}, thunks_printer->Emit({{"proto_h", GetHeaderFile(ctx, *file)}},
R"cc( R"cc(
#include "$proto_h$" #include "$proto_h$"
#include "google/protobuf/rust/cpp_kernel/cpp_api.h" #include "google/protobuf/rust/cpp_kernel/cpp_api.h"
)cc"); )cc");
} }
for (int i = 0; i < file.desc().message_type_count(); ++i) { for (int i = 0; i < file->message_type_count(); ++i) {
auto msg = file.WithDesc(file.desc().message_type(i)); auto& msg = *file->message_type(i);
GenerateRs(msg); GenerateRs(ctx, msg);
msg.printer().PrintRaw("\n"); ctx.printer().PrintRaw("\n");
if (file.is_cpp()) { if (ctx.is_cpp()) {
auto thunks_msg = msg.WithPrinter(thunks_printer.get()); auto thunks_ctx = ctx.WithPrinter(thunks_printer.get());
thunks_msg.Emit({{"Msg", msg.desc().full_name()}}, R"cc( thunks_ctx.Emit({{"Msg", msg.full_name()}}, R"cc(
// $Msg$ // $Msg$
)cc"); )cc");
GenerateThunksCc(thunks_msg); GenerateThunksCc(thunks_ctx, msg);
thunks_msg.printer().PrintRaw("\n"); thunks_ctx.printer().PrintRaw("\n");
} }
} }
if (file_desc == files_in_current_crate.front()) { if (file == files_in_current_crate.front()) {
EmitClosingOfPackageModules(file.desc().package(), file); EmitClosingOfPackageModules(ctx, file->package());
} }
return true; return true;
} }

@ -24,16 +24,16 @@ namespace compiler {
namespace rust { namespace rust {
namespace { namespace {
void MessageNew(Context<Descriptor> msg) { void MessageNew(Context& ctx, const Descriptor& msg) {
switch (msg.opts().kernel) { switch (ctx.opts().kernel) {
case Kernel::kCpp: case Kernel::kCpp:
msg.Emit({{"new_thunk", Thunk(msg, "new")}}, R"rs( ctx.Emit({{"new_thunk", Thunk(ctx, msg, "new")}}, R"rs(
Self { inner: $pbr$::MessageInner { msg: unsafe { $new_thunk$() } } } Self { inner: $pbr$::MessageInner { msg: unsafe { $new_thunk$() } } }
)rs"); )rs");
return; return;
case Kernel::kUpb: case Kernel::kUpb:
msg.Emit({{"new_thunk", Thunk(msg, "new")}}, R"rs( ctx.Emit({{"new_thunk", Thunk(ctx, msg, "new")}}, R"rs(
let arena = $pbr$::Arena::new(); let arena = $pbr$::Arena::new();
Self { Self {
inner: $pbr$::MessageInner { inner: $pbr$::MessageInner {
@ -48,16 +48,16 @@ void MessageNew(Context<Descriptor> msg) {
ABSL_LOG(FATAL) << "unreachable"; ABSL_LOG(FATAL) << "unreachable";
} }
void MessageSerialize(Context<Descriptor> msg) { void MessageSerialize(Context& ctx, const Descriptor& msg) {
switch (msg.opts().kernel) { switch (ctx.opts().kernel) {
case Kernel::kCpp: case Kernel::kCpp:
msg.Emit({{"serialize_thunk", Thunk(msg, "serialize")}}, R"rs( ctx.Emit({{"serialize_thunk", Thunk(ctx, msg, "serialize")}}, R"rs(
unsafe { $serialize_thunk$(self.inner.msg) } unsafe { $serialize_thunk$(self.inner.msg) }
)rs"); )rs");
return; return;
case Kernel::kUpb: case Kernel::kUpb:
msg.Emit({{"serialize_thunk", Thunk(msg, "serialize")}}, R"rs( ctx.Emit({{"serialize_thunk", Thunk(ctx, msg, "serialize")}}, R"rs(
let arena = $pbr$::Arena::new(); let arena = $pbr$::Arena::new();
let mut len = 0; let mut len = 0;
unsafe { unsafe {
@ -71,12 +71,12 @@ void MessageSerialize(Context<Descriptor> msg) {
ABSL_LOG(FATAL) << "unreachable"; ABSL_LOG(FATAL) << "unreachable";
} }
void MessageDeserialize(Context<Descriptor> msg) { void MessageDeserialize(Context& ctx, const Descriptor& msg) {
switch (msg.opts().kernel) { switch (ctx.opts().kernel) {
case Kernel::kCpp: case Kernel::kCpp:
msg.Emit( ctx.Emit(
{ {
{"deserialize_thunk", Thunk(msg, "deserialize")}, {"deserialize_thunk", Thunk(ctx, msg, "deserialize")},
}, },
R"rs( R"rs(
let success = unsafe { let success = unsafe {
@ -92,7 +92,7 @@ void MessageDeserialize(Context<Descriptor> msg) {
return; return;
case Kernel::kUpb: case Kernel::kUpb:
msg.Emit({{"deserialize_thunk", Thunk(msg, "parse")}}, R"rs( ctx.Emit({{"deserialize_thunk", Thunk(ctx, msg, "parse")}}, R"rs(
let arena = $pbr$::Arena::new(); let arena = $pbr$::Arena::new();
let msg = unsafe { let msg = unsafe {
$deserialize_thunk$(data.as_ptr(), data.len(), arena.raw()) $deserialize_thunk$(data.as_ptr(), data.len(), arena.raw())
@ -115,15 +115,15 @@ void MessageDeserialize(Context<Descriptor> msg) {
ABSL_LOG(FATAL) << "unreachable"; ABSL_LOG(FATAL) << "unreachable";
} }
void MessageExterns(Context<Descriptor> msg) { void MessageExterns(Context& ctx, const Descriptor& msg) {
switch (msg.opts().kernel) { switch (ctx.opts().kernel) {
case Kernel::kCpp: case Kernel::kCpp:
msg.Emit( ctx.Emit(
{ {
{"new_thunk", Thunk(msg, "new")}, {"new_thunk", Thunk(ctx, msg, "new")},
{"delete_thunk", Thunk(msg, "delete")}, {"delete_thunk", Thunk(ctx, msg, "delete")},
{"serialize_thunk", Thunk(msg, "serialize")}, {"serialize_thunk", Thunk(ctx, msg, "serialize")},
{"deserialize_thunk", Thunk(msg, "deserialize")}, {"deserialize_thunk", Thunk(ctx, msg, "deserialize")},
}, },
R"rs( R"rs(
fn $new_thunk$() -> $pbi$::RawMessage; fn $new_thunk$() -> $pbi$::RawMessage;
@ -134,11 +134,11 @@ void MessageExterns(Context<Descriptor> msg) {
return; return;
case Kernel::kUpb: case Kernel::kUpb:
msg.Emit( ctx.Emit(
{ {
{"new_thunk", Thunk(msg, "new")}, {"new_thunk", Thunk(ctx, msg, "new")},
{"serialize_thunk", Thunk(msg, "serialize")}, {"serialize_thunk", Thunk(ctx, msg, "serialize")},
{"deserialize_thunk", Thunk(msg, "parse")}, {"deserialize_thunk", Thunk(ctx, msg, "parse")},
}, },
R"rs( R"rs(
fn $new_thunk$(arena: $pbi$::RawArena) -> $pbi$::RawMessage; fn $new_thunk$(arena: $pbi$::RawArena) -> $pbi$::RawMessage;
@ -151,36 +151,37 @@ void MessageExterns(Context<Descriptor> msg) {
ABSL_LOG(FATAL) << "unreachable"; ABSL_LOG(FATAL) << "unreachable";
} }
void MessageDrop(Context<Descriptor> msg) { void MessageDrop(Context& ctx, const Descriptor& msg) {
if (msg.is_upb()) { if (ctx.is_upb()) {
// Nothing to do here; drop glue (which will run drop(self.arena) // Nothing to do here; drop glue (which will run drop(self.arena)
// automatically) is sufficient. // automatically) is sufficient.
return; return;
} }
msg.Emit({{"delete_thunk", Thunk(msg, "delete")}}, R"rs( ctx.Emit({{"delete_thunk", Thunk(ctx, msg, "delete")}}, R"rs(
unsafe { $delete_thunk$(self.inner.msg); } unsafe { $delete_thunk$(self.inner.msg); }
)rs"); )rs");
} }
void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) { void GetterForViewOrMut(Context& ctx, const FieldDescriptor& field,
auto fieldName = field.desc().name(); bool is_mut) {
auto fieldType = field.desc().type(); auto fieldName = field.name();
auto getter_thunk = Thunk(field, "get"); auto fieldType = field.type();
auto setter_thunk = Thunk(field, "set"); auto getter_thunk = Thunk(ctx, field, "get");
auto clearer_thunk = Thunk(field, "clear"); auto setter_thunk = Thunk(ctx, field, "set");
auto clearer_thunk = Thunk(ctx, field, "clear");
// If we're dealing with a Mut, the getter must be supplied // If we're dealing with a Mut, the getter must be supplied
// self.inner.msg() whereas a View has to be supplied self.msg // self.inner.msg() whereas a View has to be supplied self.msg
auto self = is_mut ? "self.inner.msg()" : "self.msg"; auto self = is_mut ? "self.inner.msg()" : "self.msg";
if (fieldType == FieldDescriptor::TYPE_MESSAGE) { if (fieldType == FieldDescriptor::TYPE_MESSAGE) {
Context<Descriptor> d = field.WithDesc(field.desc().message_type()); const Descriptor& msg = *field.message_type();
// TODO: support messages which are defined in other crates. // TODO: support messages which are defined in other crates.
if (!IsInCurrentlyGeneratingCrate(d)) { if (!IsInCurrentlyGeneratingCrate(ctx, msg)) {
return; return;
} }
auto prefix = "crate::" + GetCrateRelativeQualifiedPath(d); auto prefix = "crate::" + GetCrateRelativeQualifiedPath(ctx, msg);
field.Emit( ctx.Emit(
{ {
{"prefix", prefix}, {"prefix", prefix},
{"field", fieldName}, {"field", fieldName},
@ -190,8 +191,8 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
{ {
"view_body", "view_body",
[&] { [&] {
if (field.is_upb()) { if (ctx.is_upb()) {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
let submsg = unsafe { $getter_thunk$($self$) }; let submsg = unsafe { $getter_thunk$($self$) };
match submsg { match submsg {
None => $prefix$View::new($pbi$::Private, None => $prefix$View::new($pbi$::Private,
@ -200,7 +201,7 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
} }
)rs"); )rs");
} else { } else {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
let submsg = unsafe { $getter_thunk$($self$) }; let submsg = unsafe { $getter_thunk$($self$) };
$prefix$View::new($pbi$::Private, submsg) $prefix$View::new($pbi$::Private, submsg)
)rs"); )rs");
@ -216,18 +217,18 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
return; return;
} }
auto rsType = PrimitiveRsTypeName(field.desc()); auto rsType = PrimitiveRsTypeName(field);
if (fieldType == FieldDescriptor::TYPE_STRING || if (fieldType == FieldDescriptor::TYPE_STRING ||
fieldType == FieldDescriptor::TYPE_BYTES) { fieldType == FieldDescriptor::TYPE_BYTES) {
field.Emit({{"field", fieldName}, ctx.Emit({{"field", fieldName},
{"self", self}, {"self", self},
{"getter_thunk", getter_thunk}, {"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk}, {"setter_thunk", setter_thunk},
{"RsType", rsType}, {"RsType", rsType},
{"maybe_mutator", {"maybe_mutator",
[&] { [&] {
if (is_mut) { if (is_mut) {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> { pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> {
static VTABLE: $pbi$::BytesMutVTable = static VTABLE: $pbi$::BytesMutVTable =
$pbi$::BytesMutVTable::new( $pbi$::BytesMutVTable::new(
@ -248,9 +249,9 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
} }
} }
)rs"); )rs");
} }
}}}, }}},
R"rs( R"rs(
pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> { pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> {
let s = unsafe { $getter_thunk$($self$).as_ref() }; let s = unsafe { $getter_thunk$($self$).as_ref() };
unsafe { __pb::ProtoStr::from_utf8_unchecked(s).into() } unsafe { __pb::ProtoStr::from_utf8_unchecked(s).into() }
@ -259,19 +260,19 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
$maybe_mutator$ $maybe_mutator$
)rs"); )rs");
} else { } else {
field.Emit({{"field", fieldName}, ctx.Emit({{"field", fieldName},
{"getter_thunk", getter_thunk}, {"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk}, {"setter_thunk", setter_thunk},
{"clearer_thunk", clearer_thunk}, {"clearer_thunk", clearer_thunk},
{"self", self}, {"self", self},
{"RsType", rsType}, {"RsType", rsType},
{"maybe_mutator", {"maybe_mutator",
[&] { [&] {
// TODO: once the rust public api is accessible, // TODO: once the rust public api is accessible,
// by tooling, ensure that this only appears for the // by tooling, ensure that this only appears for the
// mutational pathway // mutational pathway
if (is_mut && fieldType) { if (is_mut && fieldType) {
field.Emit({}, R"rs( ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> { pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> {
static VTABLE: $pbi$::PrimitiveVTable<$RsType$> = static VTABLE: $pbi$::PrimitiveVTable<$RsType$> =
$pbi$::PrimitiveVTable::new( $pbi$::PrimitiveVTable::new(
@ -290,9 +291,9 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
} }
} }
)rs"); )rs");
} }
}}}, }}},
R"rs( R"rs(
pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> { pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> {
unsafe { $getter_thunk$($self$) } unsafe { $getter_thunk$($self$) }
} }
@ -302,93 +303,89 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
} }
} }
void AccessorsForViewOrMut(Context<Descriptor> msg, bool is_mut) { void AccessorsForViewOrMut(Context& ctx, const Descriptor& msg, bool is_mut) {
for (int i = 0; i < msg.desc().field_count(); ++i) { for (int i = 0; i < msg.field_count(); ++i) {
auto field = msg.WithDesc(*msg.desc().field(i)); const FieldDescriptor& field = *msg.field(i);
if (field.desc().is_repeated()) continue; if (field.is_repeated()) continue;
// TODO - add cord support // TODO - add cord support
if (field.desc().options().has_ctype()) continue; if (field.options().has_ctype()) continue;
// TODO // TODO
if (field.desc().type() == FieldDescriptor::TYPE_ENUM || if (field.type() == FieldDescriptor::TYPE_ENUM ||
field.desc().type() == FieldDescriptor::TYPE_GROUP) field.type() == FieldDescriptor::TYPE_GROUP)
continue; continue;
GetterForViewOrMut(field, is_mut); GetterForViewOrMut(ctx, field, is_mut);
msg.printer().PrintRaw("\n"); ctx.printer().PrintRaw("\n");
} }
} }
} // namespace } // namespace
void GenerateRs(Context<Descriptor> msg) { void GenerateRs(Context& ctx, const Descriptor& msg) {
if (msg.desc().map_key() != nullptr) { if (msg.map_key() != nullptr) {
ABSL_LOG(WARNING) << "unsupported map field: " << msg.desc().full_name(); ABSL_LOG(WARNING) << "unsupported map field: " << msg.full_name();
return; return;
} }
msg.Emit( ctx.Emit({{"Msg", msg.name()},
{{"Msg", msg.desc().name()}, {"Msg::new", [&] { MessageNew(ctx, msg); }},
{"Msg::new", [&] { MessageNew(msg); }}, {"Msg::serialize", [&] { MessageSerialize(ctx, msg); }},
{"Msg::serialize", [&] { MessageSerialize(msg); }}, {"Msg::deserialize", [&] { MessageDeserialize(ctx, msg); }},
{"Msg::deserialize", [&] { MessageDeserialize(msg); }}, {"Msg::drop", [&] { MessageDrop(ctx, msg); }},
{"Msg::drop", [&] { MessageDrop(msg); }}, {"Msg_externs", [&] { MessageExterns(ctx, msg); }},
{"Msg_externs", [&] { MessageExterns(msg); }}, {"accessor_fns",
{"accessor_fns", [&] {
[&] { for (int i = 0; i < msg.field_count(); ++i) {
for (int i = 0; i < msg.desc().field_count(); ++i) { auto& field = *msg.field(i);
auto field = msg.WithDesc(*msg.desc().field(i)); ctx.Emit({{"comment", FieldInfoComment(ctx, field)}}, R"rs(
msg.Emit({{"comment", FieldInfoComment(field)}}, R"rs(
// $comment$ // $comment$
)rs"); )rs");
GenerateAccessorMsgImpl(field); GenerateAccessorMsgImpl(ctx, field);
msg.printer().PrintRaw("\n"); ctx.printer().PrintRaw("\n");
} }
}}, }},
{"oneof_accessor_fns", {"oneof_accessor_fns",
[&] { [&] {
for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) { for (int i = 0; i < msg.real_oneof_decl_count(); ++i) {
GenerateOneofAccessors( GenerateOneofAccessors(ctx, *msg.real_oneof_decl(i));
msg.WithDesc(*msg.desc().real_oneof_decl(i))); ctx.printer().PrintRaw("\n");
msg.printer().PrintRaw("\n"); }
} }},
}}, {"accessor_externs",
{"accessor_externs", [&] {
[&] { for (int i = 0; i < msg.field_count(); ++i) {
for (int i = 0; i < msg.desc().field_count(); ++i) { GenerateAccessorExternC(ctx, *msg.field(i));
GenerateAccessorExternC(msg.WithDesc(*msg.desc().field(i))); ctx.printer().PrintRaw("\n");
msg.printer().PrintRaw("\n"); }
} }},
}}, {"oneof_externs",
{"oneof_externs", [&] {
[&] { for (int i = 0; i < msg.real_oneof_decl_count(); ++i) {
for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) { GenerateOneofExternC(ctx, *msg.real_oneof_decl(i));
GenerateOneofExternC(msg.WithDesc(*msg.desc().real_oneof_decl(i))); ctx.printer().PrintRaw("\n");
msg.printer().PrintRaw("\n"); }
} }},
}}, {"nested_msgs",
{"nested_msgs", [&] {
[&] { // If we have no nested types or oneofs, bail out without
// If we have no nested types or oneofs, bail out without emitting // emitting an empty mod SomeMsg_.
// an empty mod SomeMsg_. if (msg.nested_type_count() == 0 &&
if (msg.desc().nested_type_count() == 0 && msg.real_oneof_decl_count() == 0) {
msg.desc().real_oneof_decl_count() == 0) { return;
return; }
} ctx.Emit(
msg.Emit( {{"Msg", msg.name()},
{{"Msg", msg.desc().name()}, {"nested_msgs",
{"nested_msgs", [&] {
[&] { for (int i = 0; i < msg.nested_type_count(); ++i) {
for (int i = 0; i < msg.desc().nested_type_count(); ++i) { GenerateRs(ctx, *msg.nested_type(i));
auto nested_msg = msg.WithDesc(msg.desc().nested_type(i)); }
GenerateRs(nested_msg); }},
} {"oneofs",
}}, [&] {
{"oneofs", for (int i = 0; i < msg.real_oneof_decl_count(); ++i) {
[&] { GenerateOneofDefinition(ctx, *msg.real_oneof_decl(i));
for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) { }
GenerateOneofDefinition( }}},
msg.WithDesc(*msg.desc().real_oneof_decl(i))); R"rs(
}
}}},
R"rs(
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub mod $Msg$_ { pub mod $Msg$_ {
$nested_msgs$ $nested_msgs$
@ -396,10 +393,12 @@ void GenerateRs(Context<Descriptor> msg) {
$oneofs$ $oneofs$
} // mod $Msg$_ } // mod $Msg$_
)rs"); )rs");
}}, }},
{"accessor_fns_for_views", [&] { AccessorsForViewOrMut(msg, false); }}, {"accessor_fns_for_views",
{"accessor_fns_for_muts", [&] { AccessorsForViewOrMut(msg, true); }}}, [&] { AccessorsForViewOrMut(ctx, msg, false); }},
R"rs( {"accessor_fns_for_muts",
[&] { AccessorsForViewOrMut(ctx, msg, true); }}},
R"rs(
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
// TODO: Implement support for debug redaction // TODO: Implement support for debug redaction
#[derive(Debug)] #[derive(Debug)]
@ -542,9 +541,9 @@ void GenerateRs(Context<Descriptor> msg) {
$nested_msgs$ $nested_msgs$
)rs"); )rs");
if (msg.is_cpp()) { if (ctx.is_cpp()) {
msg.printer().PrintRaw("\n"); ctx.printer().PrintRaw("\n");
msg.Emit({{"Msg", msg.desc().name()}}, R"rs( ctx.Emit({{"Msg", msg.name()}}, R"rs(
impl $Msg$ { impl $Msg$ {
pub fn __unstable_wrap_cpp_grant_permission_to_break(msg: $pbi$::RawMessage) -> Self { pub fn __unstable_wrap_cpp_grant_permission_to_break(msg: $pbi$::RawMessage) -> Self {
Self { inner: $pbr$::MessageInner { msg } } Self { inner: $pbr$::MessageInner { msg } }
@ -558,39 +557,37 @@ void GenerateRs(Context<Descriptor> msg) {
} }
// Generates code for a particular message in `.pb.thunk.cc`. // Generates code for a particular message in `.pb.thunk.cc`.
void GenerateThunksCc(Context<Descriptor> msg) { void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
ABSL_CHECK(msg.is_cpp()); ABSL_CHECK(ctx.is_cpp());
if (msg.desc().map_key() != nullptr) { if (msg.map_key() != nullptr) {
ABSL_LOG(WARNING) << "unsupported map field: " << msg.desc().full_name(); ABSL_LOG(WARNING) << "unsupported map field: " << msg.full_name();
return; return;
} }
msg.Emit( ctx.Emit(
{{"abi", "\"C\""}, // Workaround for syntax highlight bug in VSCode. {{"abi", "\"C\""}, // Workaround for syntax highlight bug in VSCode.
{"Msg", msg.desc().name()}, {"Msg", msg.name()},
{"QualifiedMsg", cpp::QualifiedClassName(&msg.desc())}, {"QualifiedMsg", cpp::QualifiedClassName(&msg)},
{"new_thunk", Thunk(msg, "new")}, {"new_thunk", Thunk(ctx, msg, "new")},
{"delete_thunk", Thunk(msg, "delete")}, {"delete_thunk", Thunk(ctx, msg, "delete")},
{"serialize_thunk", Thunk(msg, "serialize")}, {"serialize_thunk", Thunk(ctx, msg, "serialize")},
{"deserialize_thunk", Thunk(msg, "deserialize")}, {"deserialize_thunk", Thunk(ctx, msg, "deserialize")},
{"nested_msg_thunks", {"nested_msg_thunks",
[&] { [&] {
for (int i = 0; i < msg.desc().nested_type_count(); ++i) { for (int i = 0; i < msg.nested_type_count(); ++i) {
Context<Descriptor> nested_msg = GenerateThunksCc(ctx, *msg.nested_type(i));
msg.WithDesc(msg.desc().nested_type(i));
GenerateThunksCc(nested_msg);
} }
}}, }},
{"accessor_thunks", {"accessor_thunks",
[&] { [&] {
for (int i = 0; i < msg.desc().field_count(); ++i) { for (int i = 0; i < msg.field_count(); ++i) {
GenerateAccessorThunkCc(msg.WithDesc(*msg.desc().field(i))); GenerateAccessorThunkCc(ctx, *msg.field(i));
} }
}}, }},
{"oneof_thunks", {"oneof_thunks",
[&] { [&] {
for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) { for (int i = 0; i < msg.real_oneof_decl_count(); ++i) {
GenerateOneofThunkCc(msg.WithDesc(*msg.desc().real_oneof_decl(i))); GenerateOneofThunkCc(ctx, *msg.real_oneof_decl(i));
} }
}}}, }}},
R"cc( R"cc(

@ -21,10 +21,10 @@ namespace compiler {
namespace rust { namespace rust {
// Generates code for a particular message in `.pb.rs`. // Generates code for a particular message in `.pb.rs`.
void GenerateRs(Context<Descriptor> msg); void GenerateRs(Context& ctx, const Descriptor& msg);
// Generates code for a particular message in `.pb.thunk.cc`. // Generates code for a particular message in `.pb.thunk.cc`.
void GenerateThunksCc(Context<Descriptor> msg); void GenerateThunksCc(Context& ctx, const Descriptor& msg);
} // namespace rust } // namespace rust
} // namespace compiler } // namespace compiler

@ -26,22 +26,23 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
namespace { namespace {
std::string GetUnderscoreDelimitedFullName(Context<Descriptor> msg) { std::string GetUnderscoreDelimitedFullName(Context& ctx,
std::string result = msg.desc().full_name(); const Descriptor& msg) {
std::string result = msg.full_name();
absl::StrReplaceAll({{".", "_"}}, &result); absl::StrReplaceAll({{".", "_"}}, &result);
return result; return result;
} }
} // namespace } // namespace
std::string GetCrateName(Context<FileDescriptor> dep) { std::string GetCrateName(Context& ctx, const FileDescriptor& dep) {
absl::string_view path = dep.desc().name(); absl::string_view path = dep.name();
auto basename = path.substr(path.rfind('/') + 1); auto basename = path.substr(path.rfind('/') + 1);
return absl::StrReplaceAll(basename, {{".", "_"}, {"-", "_"}}); return absl::StrReplaceAll(basename, {{".", "_"}, {"-", "_"}});
} }
std::string GetRsFile(Context<FileDescriptor> file) { std::string GetRsFile(Context& ctx, const FileDescriptor& file) {
auto basename = StripProto(file.desc().name()); auto basename = StripProto(file.name());
switch (auto k = file.opts().kernel) { switch (auto k = ctx.opts().kernel) {
case Kernel::kUpb: case Kernel::kUpb:
return absl::StrCat(basename, ".u.pb.rs"); return absl::StrCat(basename, ".u.pb.rs");
case Kernel::kCpp: case Kernel::kCpp:
@ -52,42 +53,41 @@ std::string GetRsFile(Context<FileDescriptor> file) {
} }
} }
std::string GetThunkCcFile(Context<FileDescriptor> file) { std::string GetThunkCcFile(Context& ctx, const FileDescriptor& file) {
auto basename = StripProto(file.desc().name()); auto basename = StripProto(file.name());
return absl::StrCat(basename, ".pb.thunks.cc"); return absl::StrCat(basename, ".pb.thunks.cc");
} }
std::string GetHeaderFile(Context<FileDescriptor> file) { std::string GetHeaderFile(Context& ctx, const FileDescriptor& file) {
auto basename = StripProto(file.desc().name()); auto basename = StripProto(file.name());
return absl::StrCat(basename, ".proto.h"); return absl::StrCat(basename, ".proto.h");
} }
namespace { namespace {
template <typename T> template <typename T>
std::string FieldPrefix(Context<T> field) { std::string FieldPrefix(Context& ctx, const T& field) {
// NOTE: When field.is_upb(), this functions outputs must match the symbols // NOTE: When ctx.is_upb(), this functions outputs must match the symbols
// that the upbc plugin generates exactly. Failure to do so correctly results // that the upbc plugin generates exactly. Failure to do so correctly results
// in a link-time failure. // in a link-time failure.
absl::string_view prefix = field.is_cpp() ? "__rust_proto_thunk__" : ""; absl::string_view prefix = ctx.is_cpp() ? "__rust_proto_thunk__" : "";
std::string thunk_prefix = std::string thunk_prefix = absl::StrCat(
absl::StrCat(prefix, GetUnderscoreDelimitedFullName( prefix, GetUnderscoreDelimitedFullName(ctx, *field.containing_type()));
field.WithDesc(field.desc().containing_type())));
return thunk_prefix; return thunk_prefix;
} }
template <typename T> template <typename T>
std::string Thunk(Context<T> field, absl::string_view op) { std::string Thunk(Context& ctx, const T& field, absl::string_view op) {
std::string thunk = FieldPrefix(field); std::string thunk = FieldPrefix(ctx, field);
absl::string_view format; absl::string_view format;
if (field.is_upb() && op == "get") { if (ctx.is_upb() && op == "get") {
// upb getter is simply the field name (no "get" in the name). // upb getter is simply the field name (no "get" in the name).
format = "_$1"; format = "_$1";
} else if (field.is_upb() && op == "get_mut") { } else if (ctx.is_upb() && op == "get_mut") {
// same as above, with with `mutable` prefix // same as above, with with `mutable` prefix
format = "_mutable_$1"; format = "_mutable_$1";
} else if (field.is_upb() && op == "case") { } else if (ctx.is_upb() && op == "case") {
// some upb functions are in the order x_op compared to has/set/clear which // some upb functions are in the order x_op compared to has/set/clear which
// are in the other order e.g. op_x. // are in the other order e.g. op_x.
format = "_$1_$0"; format = "_$1_$0";
@ -95,51 +95,53 @@ std::string Thunk(Context<T> field, absl::string_view op) {
format = "_$0_$1"; format = "_$0_$1";
} }
absl::SubstituteAndAppend(&thunk, format, op, field.desc().name()); absl::SubstituteAndAppend(&thunk, format, op, field.name());
return thunk; return thunk;
} }
std::string ThunkMapOrRepeated(Context<FieldDescriptor> field, std::string ThunkMapOrRepeated(Context& ctx, const FieldDescriptor& field,
absl::string_view op) { absl::string_view op) {
if (!field.is_upb()) { if (!ctx.is_upb()) {
return Thunk<FieldDescriptor>(field, op); return Thunk<FieldDescriptor>(ctx, field, op);
} }
std::string thunk = absl::StrCat("_", FieldPrefix(field)); std::string thunk = absl::StrCat("_", FieldPrefix(ctx, field));
absl::string_view format; absl::string_view format;
if (op == "get") { if (op == "get") {
format = field.desc().is_map() ? "_$1_upb_map" : "_$1_upb_array"; format = field.is_map() ? "_$1_upb_map" : "_$1_upb_array";
} else if (op == "get_mut") { } else if (op == "get_mut") {
format = format = field.is_map() ? "_$1_mutable_upb_map" : "_$1_mutable_upb_array";
field.desc().is_map() ? "_$1_mutable_upb_map" : "_$1_mutable_upb_array";
} else { } else {
return Thunk<FieldDescriptor>(field, op); return Thunk<FieldDescriptor>(ctx, field, op);
} }
absl::SubstituteAndAppend(&thunk, format, op, field.desc().name()); absl::SubstituteAndAppend(&thunk, format, op, field.name());
return thunk; return thunk;
} }
} // namespace } // namespace
std::string Thunk(Context<FieldDescriptor> field, absl::string_view op) { std::string Thunk(Context& ctx, const FieldDescriptor& field,
if (field.desc().is_map() || field.desc().is_repeated()) { absl::string_view op) {
return ThunkMapOrRepeated(field, op); if (field.is_map() || field.is_repeated()) {
return ThunkMapOrRepeated(ctx, field, op);
} }
return Thunk<FieldDescriptor>(field, op); return Thunk<FieldDescriptor>(ctx, field, op);
} }
std::string Thunk(Context<OneofDescriptor> field, absl::string_view op) { std::string Thunk(Context& ctx, const OneofDescriptor& field,
return Thunk<OneofDescriptor>(field, op); absl::string_view op) {
return Thunk<OneofDescriptor>(ctx, field, op);
} }
std::string Thunk(Context<Descriptor> msg, absl::string_view op) { std::string Thunk(Context& ctx, const Descriptor& msg, absl::string_view op) {
absl::string_view prefix = msg.is_cpp() ? "__rust_proto_thunk__" : ""; absl::string_view prefix = ctx.is_cpp() ? "__rust_proto_thunk__" : "";
return absl::StrCat(prefix, GetUnderscoreDelimitedFullName(msg), "_", op); return absl::StrCat(prefix, GetUnderscoreDelimitedFullName(ctx, msg), "_",
op);
} }
std::string PrimitiveRsTypeName(const FieldDescriptor& desc) { std::string PrimitiveRsTypeName(const FieldDescriptor& field) {
switch (desc.type()) { switch (field.type()) {
case FieldDescriptor::TYPE_BOOL: case FieldDescriptor::TYPE_BOOL:
return "bool"; return "bool";
case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_INT32:
@ -167,7 +169,7 @@ std::string PrimitiveRsTypeName(const FieldDescriptor& desc) {
default: default:
break; break;
} }
ABSL_LOG(FATAL) << "Unsupported field type: " << desc.type_name(); ABSL_LOG(FATAL) << "Unsupported field type: " << field.type_name();
return ""; return "";
} }
@ -180,20 +182,18 @@ std::string PrimitiveRsTypeName(const FieldDescriptor& desc) {
// //
// If the message has no package and no containing messages then this returns // If the message has no package and no containing messages then this returns
// empty string. // empty string.
std::string RustModule(Context<Descriptor> msg) { std::string RustModule(Context& ctx, const Descriptor& msg) {
const Descriptor& desc = msg.desc();
std::vector<std::string> modules; std::vector<std::string> modules;
std::vector<std::string> package_modules = std::vector<std::string> package_modules =
absl::StrSplit(desc.file()->package(), '.', absl::SkipEmpty()); absl::StrSplit(msg.file()->package(), '.', absl::SkipEmpty());
modules.insert(modules.begin(), package_modules.begin(), modules.insert(modules.begin(), package_modules.begin(),
package_modules.end()); package_modules.end());
// Innermost to outermost order. // Innermost to outermost order.
std::vector<std::string> modules_from_containing_types; std::vector<std::string> modules_from_containing_types;
const Descriptor* parent = desc.containing_type(); const Descriptor* parent = msg.containing_type();
while (parent != nullptr) { while (parent != nullptr) {
modules_from_containing_types.push_back(absl::StrCat(parent->name(), "_")); modules_from_containing_types.push_back(absl::StrCat(parent->name(), "_"));
parent = parent->containing_type(); parent = parent->containing_type();
@ -213,27 +213,25 @@ std::string RustModule(Context<Descriptor> msg) {
return absl::StrJoin(modules, "::"); return absl::StrJoin(modules, "::");
} }
std::string RustInternalModuleName(Context<FileDescriptor> file) { std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file) {
// TODO: Introduce a more robust mangling here to avoid conflicts // TODO: Introduce a more robust mangling here to avoid conflicts
// between `foo/bar/baz.proto` and `foo_bar/baz.proto`. // between `foo/bar/baz.proto` and `foo_bar/baz.proto`.
return absl::StrReplaceAll(StripProto(file.desc().name()), {{"/", "_"}}); return absl::StrReplaceAll(StripProto(file.name()), {{"/", "_"}});
} }
std::string GetCrateRelativeQualifiedPath(Context<Descriptor> msg) { std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg) {
return absl::StrCat(RustModule(msg), msg.desc().name()); return absl::StrCat(RustModule(ctx, msg), msg.name());
} }
std::string FieldInfoComment(Context<FieldDescriptor> field) { std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field) {
absl::string_view label = absl::string_view label = field.is_repeated() ? "repeated" : "optional";
field.desc().is_repeated() ? "repeated" : "optional"; std::string comment = absl::StrCat(field.name(), ": ", label, " ",
std::string comment = FieldDescriptor::TypeName(field.type()));
absl::StrCat(field.desc().name(), ": ", label, " ",
FieldDescriptor::TypeName(field.desc().type()));
if (auto* m = field.desc().message_type()) { if (auto* m = field.message_type()) {
absl::StrAppend(&comment, " ", m->full_name()); absl::StrAppend(&comment, " ", m->full_name());
} }
if (auto* m = field.desc().enum_type()) { if (auto* m = field.enum_type()) {
absl::StrAppend(&comment, " ", m->full_name()); absl::StrAppend(&comment, " ", m->full_name());
} }

@ -19,25 +19,27 @@ namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
std::string GetCrateName(Context<FileDescriptor> dep); std::string GetCrateName(Context& ctx, const FileDescriptor& dep);
std::string GetRsFile(Context<FileDescriptor> file); std::string GetRsFile(Context& ctx, const FileDescriptor& file);
std::string GetThunkCcFile(Context<FileDescriptor> file); std::string GetThunkCcFile(Context& ctx, const FileDescriptor& file);
std::string GetHeaderFile(Context<FileDescriptor> file); std::string GetHeaderFile(Context& ctx, const FileDescriptor& file);
std::string Thunk(Context<FieldDescriptor> field, absl::string_view op); std::string Thunk(Context& ctx, const FieldDescriptor& field,
std::string Thunk(Context<OneofDescriptor> field, absl::string_view op); absl::string_view op);
std::string Thunk(Context& ctx, const OneofDescriptor& field,
absl::string_view op);
std::string Thunk(Context<Descriptor> msg, absl::string_view op); std::string Thunk(Context& ctx, const Descriptor& msg, absl::string_view op);
std::string PrimitiveRsTypeName(const FieldDescriptor& desc); std::string PrimitiveRsTypeName(const FieldDescriptor& field);
std::string FieldInfoComment(Context<FieldDescriptor> field); std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field);
std::string RustModule(Context<Descriptor> msg); std::string RustModule(Context& ctx, const Descriptor& msg);
std::string RustInternalModuleName(Context<FileDescriptor> file); std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file);
std::string GetCrateRelativeQualifiedPath(Context<Descriptor> msg); std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg);
} // namespace rust } // namespace rust
} // namespace compiler } // namespace compiler

@ -78,27 +78,25 @@ std::string ToCamelCase(absl::string_view name) {
return cpp::UnderscoresToCamelCase(name, /* upper initial letter */ true); return cpp::UnderscoresToCamelCase(name, /* upper initial letter */ true);
} }
std::string oneofViewEnumRsName(const OneofDescriptor& desc) { std::string oneofViewEnumRsName(const OneofDescriptor& oneof) {
return ToCamelCase(desc.name()); return ToCamelCase(oneof.name());
} }
std::string oneofMutEnumRsName(const OneofDescriptor& desc) { std::string oneofMutEnumRsName(const OneofDescriptor& oneof) {
return ToCamelCase(desc.name()) + "Mut"; return ToCamelCase(oneof.name()) + "Mut";
} }
std::string oneofCaseEnumName(const OneofDescriptor& desc) { std::string oneofCaseEnumName(const OneofDescriptor& oneof) {
// Note: This is the name used for the cpp Case enum, we use it for both // Note: This is the name used for the cpp Case enum, we use it for both
// the Rust Case enum as well as for the cpp case enum in the cpp thunk. // the Rust Case enum as well as for the cpp case enum in the cpp thunk.
return ToCamelCase(desc.name()) + "Case"; return ToCamelCase(oneof.name()) + "Case";
} }
std::string RsTypeNameView(Context<FieldDescriptor> field) { std::string RsTypeNameView(Context& ctx, const FieldDescriptor& field) {
const auto& desc = field.desc(); if (field.options().has_ctype()) {
if (desc.options().has_ctype()) {
return ""; // TODO: b/308792377 - ctype fields not supported yet. return ""; // TODO: b/308792377 - ctype fields not supported yet.
} }
switch (desc.type()) { switch (field.type()) {
case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_INT64: case FieldDescriptor::TYPE_INT64:
case FieldDescriptor::TYPE_FIXED32: case FieldDescriptor::TYPE_FIXED32:
@ -112,7 +110,7 @@ std::string RsTypeNameView(Context<FieldDescriptor> field) {
case FieldDescriptor::TYPE_FLOAT: case FieldDescriptor::TYPE_FLOAT:
case FieldDescriptor::TYPE_DOUBLE: case FieldDescriptor::TYPE_DOUBLE:
case FieldDescriptor::TYPE_BOOL: case FieldDescriptor::TYPE_BOOL:
return PrimitiveRsTypeName(desc); return PrimitiveRsTypeName(field);
case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_BYTES:
return "&'msg [u8]"; return "&'msg [u8]";
case FieldDescriptor::TYPE_STRING: case FieldDescriptor::TYPE_STRING:
@ -120,23 +118,21 @@ std::string RsTypeNameView(Context<FieldDescriptor> field) {
case FieldDescriptor::TYPE_MESSAGE: case FieldDescriptor::TYPE_MESSAGE:
return absl::StrCat( return absl::StrCat(
"::__pb::View<'msg, crate::", "::__pb::View<'msg, crate::",
GetCrateRelativeQualifiedPath(field.WithDesc(desc.message_type())), GetCrateRelativeQualifiedPath(ctx, *field.message_type()), ">");
">");
case FieldDescriptor::TYPE_ENUM: // TODO: b/300257770 - Support enums. case FieldDescriptor::TYPE_ENUM: // TODO: b/300257770 - Support enums.
case FieldDescriptor::TYPE_GROUP: // Not supported yet. case FieldDescriptor::TYPE_GROUP: // Not supported yet.
return ""; return "";
} }
ABSL_LOG(FATAL) << "Unexpected field type: " << desc.type_name(); ABSL_LOG(FATAL) << "Unexpected field type: " << field.type_name();
return ""; return "";
} }
std::string RsTypeNameMut(Context<FieldDescriptor> field) { std::string RsTypeNameMut(Context& ctx, const FieldDescriptor& field) {
const auto& desc = field.desc(); if (field.options().has_ctype()) {
if (desc.options().has_ctype()) {
return ""; // TODO: b/308792377 - ctype fields not supported yet. return ""; // TODO: b/308792377 - ctype fields not supported yet.
} }
switch (desc.type()) { switch (field.type()) {
case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_INT64: case FieldDescriptor::TYPE_INT64:
case FieldDescriptor::TYPE_FIXED32: case FieldDescriptor::TYPE_FIXED32:
@ -152,56 +148,54 @@ std::string RsTypeNameMut(Context<FieldDescriptor> field) {
case FieldDescriptor::TYPE_BOOL: case FieldDescriptor::TYPE_BOOL:
case FieldDescriptor::TYPE_BYTES: case FieldDescriptor::TYPE_BYTES:
case FieldDescriptor::TYPE_STRING: case FieldDescriptor::TYPE_STRING:
return absl::StrCat("::__pb::Mut<'msg, ", PrimitiveRsTypeName(desc), ">"); return absl::StrCat("::__pb::Mut<'msg, ", PrimitiveRsTypeName(field),
">");
case FieldDescriptor::TYPE_MESSAGE: case FieldDescriptor::TYPE_MESSAGE:
return absl::StrCat( return absl::StrCat(
"::__pb::Mut<'msg, crate::", "::__pb::Mut<'msg, crate::",
GetCrateRelativeQualifiedPath(field.WithDesc(desc.message_type())), GetCrateRelativeQualifiedPath(ctx, *field.message_type()), ">");
">");
case FieldDescriptor::TYPE_ENUM: // TODO: b/300257770 - Support enums. case FieldDescriptor::TYPE_ENUM: // TODO: b/300257770 - Support enums.
case FieldDescriptor::TYPE_GROUP: // Not supported yet. case FieldDescriptor::TYPE_GROUP: // Not supported yet.
return ""; return "";
} }
ABSL_LOG(FATAL) << "Unexpected field type: " << desc.type_name(); ABSL_LOG(FATAL) << "Unexpected field type: " << field.type_name();
return ""; return "";
} }
} // namespace } // namespace
void GenerateOneofDefinition(Context<OneofDescriptor> oneof) { void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) {
const auto& desc = oneof.desc(); ctx.Emit(
{{"view_enum_name", oneofViewEnumRsName(oneof)},
oneof.Emit( {"mut_enum_name", oneofMutEnumRsName(oneof)},
{{"view_enum_name", oneofViewEnumRsName(desc)},
{"mut_enum_name", oneofMutEnumRsName(desc)},
{"view_fields", {"view_fields",
[&] { [&] {
for (int i = 0; i < desc.field_count(); ++i) { for (int i = 0; i < oneof.field_count(); ++i) {
const auto& field = *desc.field(i); auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameView(oneof.WithDesc(field)); std::string rs_type = RsTypeNameView(ctx, field);
if (rs_type.empty()) { if (rs_type.empty()) {
continue; continue;
} }
oneof.Emit({{"name", ToCamelCase(field.name())}, ctx.Emit({{"name", ToCamelCase(field.name())},
{"type", rs_type}, {"type", rs_type},
{"number", std::to_string(field.number())}}, {"number", std::to_string(field.number())}},
R"rs($name$($type$) = $number$, R"rs($name$($type$) = $number$,
)rs"); )rs");
} }
}}, }},
{"mut_fields", {"mut_fields",
[&] { [&] {
for (int i = 0; i < desc.field_count(); ++i) { for (int i = 0; i < oneof.field_count(); ++i) {
const auto& field = *desc.field(i); auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameMut(oneof.WithDesc(field)); std::string rs_type = RsTypeNameMut(ctx, field);
if (rs_type.empty()) { if (rs_type.empty()) {
continue; continue;
} }
oneof.Emit({{"name", ToCamelCase(field.name())}, ctx.Emit({{"name", ToCamelCase(field.name())},
{"type", rs_type}, {"type", rs_type},
{"number", std::to_string(field.number())}}, {"number", std::to_string(field.number())}},
R"rs($name$($type$) = $number$, R"rs($name$($type$) = $number$,
)rs"); )rs");
} }
}}}, }}},
@ -236,18 +230,18 @@ void GenerateOneofDefinition(Context<OneofDescriptor> oneof) {
// Note: This enum is used as the Thunk return type for getting which case is // Note: This enum is used as the Thunk return type for getting which case is
// used: it exactly matches the generate case enum that both cpp and upb use. // used: it exactly matches the generate case enum that both cpp and upb use.
oneof.Emit({{"case_enum_name", oneofCaseEnumName(desc)}, ctx.Emit({{"case_enum_name", oneofCaseEnumName(oneof)},
{"cases", {"cases",
[&] { [&] {
for (int i = 0; i < desc.field_count(); ++i) { for (int i = 0; i < oneof.field_count(); ++i) {
const auto& field = desc.field(i); auto& field = *oneof.field(i);
oneof.Emit({{"name", ToCamelCase(field->name())}, ctx.Emit({{"name", ToCamelCase(field.name())},
{"number", std::to_string(field->number())}}, {"number", std::to_string(field.number())}},
R"rs($name$ = $number$, R"rs($name$ = $number$,
)rs"); )rs");
} }
}}}, }}},
R"rs( R"rs(
#[repr(C)] #[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)] #[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) enum $case_enum_name$ { pub(super) enum $case_enum_name$ {
@ -260,23 +254,21 @@ void GenerateOneofDefinition(Context<OneofDescriptor> oneof) {
)rs"); )rs");
} }
void GenerateOneofAccessors(Context<OneofDescriptor> oneof) { void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof) {
const auto& desc = oneof.desc(); ctx.Emit(
{{"oneof_name", oneof.name()},
oneof.Emit( {"view_enum_name", oneofViewEnumRsName(oneof)},
{{"oneof_name", desc.name()}, {"mut_enum_name", oneofMutEnumRsName(oneof)},
{"view_enum_name", oneofViewEnumRsName(desc)}, {"case_enum_name", oneofCaseEnumName(oneof)},
{"mut_enum_name", oneofMutEnumRsName(desc)},
{"case_enum_name", oneofCaseEnumName(desc)},
{"view_cases", {"view_cases",
[&] { [&] {
for (int i = 0; i < desc.field_count(); ++i) { for (int i = 0; i < oneof.field_count(); ++i) {
const auto& field = *desc.field(i); auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameView(oneof.WithDesc(field)); std::string rs_type = RsTypeNameView(ctx, field);
if (rs_type.empty()) { if (rs_type.empty()) {
continue; continue;
} }
oneof.Emit( ctx.Emit(
{ {
{"case", ToCamelCase(field.name())}, {"case", ToCamelCase(field.name())},
{"rs_getter", field.name()}, {"rs_getter", field.name()},
@ -288,13 +280,13 @@ void GenerateOneofAccessors(Context<OneofDescriptor> oneof) {
}}, }},
{"mut_cases", {"mut_cases",
[&] { [&] {
for (int i = 0; i < desc.field_count(); ++i) { for (int i = 0; i < oneof.field_count(); ++i) {
const auto& field = *desc.field(i); auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameMut(oneof.WithDesc(field)); std::string rs_type = RsTypeNameMut(ctx, field);
if (rs_type.empty()) { if (rs_type.empty()) {
continue; continue;
} }
oneof.Emit( ctx.Emit(
{{"case", ToCamelCase(field.name())}, {{"case", ToCamelCase(field.name())},
{"rs_mut_getter", field.name() + "_mut"}, {"rs_mut_getter", field.name() + "_mut"},
{"type", rs_type}, {"type", rs_type},
@ -321,7 +313,7 @@ void GenerateOneofAccessors(Context<OneofDescriptor> oneof) {
$Msg$_::$mut_enum_name$::$case$(self.$rs_mut_getter$()$into_mut_transform$), )rs"); $Msg$_::$mut_enum_name$::$case$(self.$rs_mut_getter$()$into_mut_transform$), )rs");
} }
}}, }},
{"case_thunk", Thunk(oneof, "case")}}, {"case_thunk", Thunk(ctx, oneof, "case")}},
R"rs( R"rs(
pub fn r#$oneof_name$(&self) -> $Msg$_::$view_enum_name$ { pub fn r#$oneof_name$(&self) -> $Msg$_::$view_enum_name$ {
match unsafe { $case_thunk$(self.inner.msg) } { match unsafe { $case_thunk$(self.inner.msg) } {
@ -340,26 +332,24 @@ void GenerateOneofAccessors(Context<OneofDescriptor> oneof) {
)rs"); )rs");
} }
void GenerateOneofExternC(Context<OneofDescriptor> oneof) { void GenerateOneofExternC(Context& ctx, const OneofDescriptor& oneof) {
const auto& desc = oneof.desc(); ctx.Emit(
oneof.Emit(
{ {
{"case_enum_rs_name", oneofCaseEnumName(desc)}, {"case_enum_rs_name", oneofCaseEnumName(oneof)},
{"case_thunk", Thunk(oneof, "case")}, {"case_thunk", Thunk(ctx, oneof, "case")},
}, },
R"rs( R"rs(
fn $case_thunk$(raw_msg: $pbi$::RawMessage) -> $Msg$_::$case_enum_rs_name$; fn $case_thunk$(raw_msg: $pbi$::RawMessage) -> $Msg$_::$case_enum_rs_name$;
)rs"); )rs");
} }
void GenerateOneofThunkCc(Context<OneofDescriptor> oneof) { void GenerateOneofThunkCc(Context& ctx, const OneofDescriptor& oneof) {
const auto& desc = oneof.desc(); ctx.Emit(
oneof.Emit(
{ {
{"oneof_name", desc.name()}, {"oneof_name", oneof.name()},
{"case_enum_name", oneofCaseEnumName(desc)}, {"case_enum_name", oneofCaseEnumName(oneof)},
{"case_thunk", Thunk(oneof, "case")}, {"case_thunk", Thunk(ctx, oneof, "case")},
{"QualifiedMsg", cpp::QualifiedClassName(desc.containing_type())}, {"QualifiedMsg", cpp::QualifiedClassName(oneof.containing_type())},
}, },
R"cc( R"cc(
$QualifiedMsg$::$case_enum_name$ $case_thunk$($QualifiedMsg$* msg) { $QualifiedMsg$::$case_enum_name$ $case_thunk$($QualifiedMsg$* msg) {

@ -16,10 +16,10 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace rust { namespace rust {
void GenerateOneofDefinition(Context<OneofDescriptor> oneof); void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof);
void GenerateOneofAccessors(Context<OneofDescriptor> oneof); void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof);
void GenerateOneofExternC(Context<OneofDescriptor> oneof); void GenerateOneofExternC(Context& ctx, const OneofDescriptor& oneof);
void GenerateOneofThunkCc(Context<OneofDescriptor> oneof); void GenerateOneofThunkCc(Context& ctx, const OneofDescriptor& oneof);
} // namespace rust } // namespace rust
} // namespace compiler } // namespace compiler

Loading…
Cancel
Save