Decouple Context from the Descriptor

PiperOrigin-RevId: 592029759
pull/15107/head
Protobuf Team Bot 1 year ago committed by Copybara-Service
parent 52ee619733
commit 542ca772fa
  1. 63
      src/google/protobuf/compiler/rust/accessors/accessor_generator.h
  2. 35
      src/google/protobuf/compiler/rust/accessors/accessors.cc
  3. 6
      src/google/protobuf/compiler/rust/accessors/accessors.h
  4. 42
      src/google/protobuf/compiler/rust/accessors/helpers.cc
  5. 3
      src/google/protobuf/compiler/rust/accessors/helpers.h
  6. 87
      src/google/protobuf/compiler/rust/accessors/map.cc
  7. 124
      src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc
  8. 85
      src/google/protobuf/compiler/rust/accessors/singular_message.cc
  9. 112
      src/google/protobuf/compiler/rust/accessors/singular_scalar.cc
  10. 139
      src/google/protobuf/compiler/rust/accessors/singular_string.cc
  11. 7
      src/google/protobuf/compiler/rust/accessors/unsupported_field.cc
  12. 9
      src/google/protobuf/compiler/rust/context.cc
  13. 42
      src/google/protobuf/compiler/rust/context.h
  14. 165
      src/google/protobuf/compiler/rust/generator.cc
  15. 337
      src/google/protobuf/compiler/rust/message.cc
  16. 4
      src/google/protobuf/compiler/rust/message.h
  17. 120
      src/google/protobuf/compiler/rust/naming.cc
  18. 26
      src/google/protobuf/compiler/rust/naming.h
  19. 156
      src/google/protobuf/compiler/rust/oneof.cc
  20. 8
      src/google/protobuf/compiler/rust/oneof.h

@ -26,25 +26,26 @@ class AccessorGenerator {
AccessorGenerator() = default;
virtual ~AccessorGenerator() = default;
AccessorGenerator(const AccessorGenerator &) = delete;
AccessorGenerator(AccessorGenerator &&) = delete;
AccessorGenerator &operator=(const AccessorGenerator &) = delete;
AccessorGenerator &operator=(AccessorGenerator &&) = delete;
AccessorGenerator(const AccessorGenerator&) = delete;
AccessorGenerator(AccessorGenerator&&) = delete;
AccessorGenerator& operator=(const AccessorGenerator&) = delete;
AccessorGenerator& operator=(AccessorGenerator&&) = delete;
// Constructs a generator for the given field.
//
// Returns `nullptr` if there is no known generator for this field.
static std::unique_ptr<AccessorGenerator> For(Context<FieldDescriptor> field);
static std::unique_ptr<AccessorGenerator> For(Context& ctx,
const FieldDescriptor& field);
void GenerateMsgImpl(Context<FieldDescriptor> field) const {
InMsgImpl(field);
void GenerateMsgImpl(Context& ctx, const FieldDescriptor& field) const {
InMsgImpl(ctx, field);
}
void GenerateExternC(Context<FieldDescriptor> field) const {
InExternC(field);
void GenerateExternC(Context& ctx, const FieldDescriptor& field) const {
InExternC(ctx, field);
}
void GenerateThunkCc(Context<FieldDescriptor> field) const {
ABSL_CHECK(field.is_cpp());
InThunkCc(field);
void GenerateThunkCc(Context& ctx, const FieldDescriptor& field) const {
ABSL_CHECK(ctx.is_cpp());
InThunkCc(ctx, field);
}
private:
@ -54,53 +55,53 @@ class AccessorGenerator {
// prologue to inject variables automatically.
// Called inside the main inherent `impl Msg {}` block.
virtual void InMsgImpl(Context<FieldDescriptor> field) const {}
virtual void InMsgImpl(Context& ctx, const FieldDescriptor& field) const {}
// Called inside of a message's `extern "C" {}` block.
virtual void InExternC(Context<FieldDescriptor> field) const {}
virtual void InExternC(Context& ctx, const FieldDescriptor& field) const {}
// Called inside of an `extern "C" {}` block in the `.thunk.cc` file, if such
// a file is being generated.
virtual void InThunkCc(Context<FieldDescriptor> field) const {}
virtual void InThunkCc(Context& ctx, const FieldDescriptor& field) const {}
};
class SingularScalar final : public AccessorGenerator {
public:
~SingularScalar() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override;
void InExternC(Context<FieldDescriptor> field) const override;
void InThunkCc(Context<FieldDescriptor> field) const override;
void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
void InExternC(Context& ctx, const FieldDescriptor& field) const override;
void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
};
class SingularString final : public AccessorGenerator {
public:
~SingularString() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override;
void InExternC(Context<FieldDescriptor> field) const override;
void InThunkCc(Context<FieldDescriptor> field) const override;
void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
void InExternC(Context& ctx, const FieldDescriptor& field) const override;
void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
};
class SingularMessage final : public AccessorGenerator {
public:
~SingularMessage() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override;
void InExternC(Context<FieldDescriptor> field) const override;
void InThunkCc(Context<FieldDescriptor> field) const override;
void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
void InExternC(Context& ctx, const FieldDescriptor& field) const override;
void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
};
class RepeatedScalar final : public AccessorGenerator {
public:
~RepeatedScalar() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override;
void InExternC(Context<FieldDescriptor> field) const override;
void InThunkCc(Context<FieldDescriptor> field) const override;
void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
void InExternC(Context& ctx, const FieldDescriptor& field) const override;
void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
};
class UnsupportedField final : public AccessorGenerator {
public:
explicit UnsupportedField(std::string reason) : reason_(std::move(reason)) {}
~UnsupportedField() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override;
void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
private:
std::string reason_;
@ -109,9 +110,9 @@ class UnsupportedField final : public AccessorGenerator {
class Map final : public AccessorGenerator {
public:
~Map() override = default;
void InMsgImpl(Context<FieldDescriptor> field) const override;
void InExternC(Context<FieldDescriptor> field) const override;
void InThunkCc(Context<FieldDescriptor> field) const override;
void InMsgImpl(Context& ctx, const FieldDescriptor& field) const override;
void InExternC(Context& ctx, const FieldDescriptor& field) const override;
void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
};
} // namespace rust

@ -23,17 +23,16 @@ namespace rust {
namespace {
std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
Context<FieldDescriptor> field) {
const FieldDescriptor& desc = field.desc();
Context& ctx, const FieldDescriptor& field) {
// TODO: We do not support [ctype=FOO] (used to set the field
// type in C++ to cord or string_piece) in V0.6 API.
if (desc.options().has_ctype()) {
if (field.options().has_ctype()) {
return std::make_unique<UnsupportedField>(
"fields with ctype not supported");
}
if (desc.is_map()) {
auto value_type = desc.message_type()->map_value()->type();
if (field.is_map()) {
auto value_type = field.message_type()->map_value()->type();
switch (value_type) {
case FieldDescriptor::TYPE_BYTES:
case FieldDescriptor::TYPE_ENUM:
@ -46,7 +45,7 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
}
}
switch (desc.type()) {
switch (field.type()) {
case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_INT64:
case FieldDescriptor::TYPE_FIXED32:
@ -60,22 +59,22 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
case FieldDescriptor::TYPE_FLOAT:
case FieldDescriptor::TYPE_DOUBLE:
case FieldDescriptor::TYPE_BOOL:
if (desc.is_repeated()) {
if (field.is_repeated()) {
return std::make_unique<RepeatedScalar>();
}
return std::make_unique<SingularScalar>();
case FieldDescriptor::TYPE_BYTES:
case FieldDescriptor::TYPE_STRING:
if (desc.is_repeated()) {
if (field.is_repeated()) {
return std::make_unique<UnsupportedField>("repeated str not supported");
}
return std::make_unique<SingularString>();
case FieldDescriptor::TYPE_MESSAGE:
if (desc.is_repeated()) {
if (field.is_repeated()) {
return std::make_unique<UnsupportedField>("repeated msg not supported");
}
if (!field.generator_context().is_file_in_current_crate(
desc.message_type()->file())) {
if (!ctx.generator_context().is_file_in_current_crate(
*field.message_type()->file())) {
return std::make_unique<UnsupportedField>(
"message fields that are imported from another proto_library"
" (defined in a separate Rust crate) are not supported");
@ -89,21 +88,21 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
return std::make_unique<UnsupportedField>("group not supported");
}
ABSL_LOG(FATAL) << "Unexpected field type: " << desc.type();
ABSL_LOG(FATAL) << "Unexpected field type: " << field.type();
}
} // namespace
void GenerateAccessorMsgImpl(Context<FieldDescriptor> field) {
AccessorGeneratorFor(field)->GenerateMsgImpl(field);
void GenerateAccessorMsgImpl(Context& ctx, const FieldDescriptor& field) {
AccessorGeneratorFor(ctx, field)->GenerateMsgImpl(ctx, field);
}
void GenerateAccessorExternC(Context<FieldDescriptor> field) {
AccessorGeneratorFor(field)->GenerateExternC(field);
void GenerateAccessorExternC(Context& ctx, const FieldDescriptor& field) {
AccessorGeneratorFor(ctx, field)->GenerateExternC(ctx, field);
}
void GenerateAccessorThunkCc(Context<FieldDescriptor> field) {
AccessorGeneratorFor(field)->GenerateThunkCc(field);
void GenerateAccessorThunkCc(Context& ctx, const FieldDescriptor& field) {
AccessorGeneratorFor(ctx, field)->GenerateThunkCc(ctx, field);
}
} // namespace rust

@ -16,9 +16,9 @@ namespace protobuf {
namespace compiler {
namespace rust {
void GenerateAccessorMsgImpl(Context<FieldDescriptor> field);
void GenerateAccessorExternC(Context<FieldDescriptor> field);
void GenerateAccessorThunkCc(Context<FieldDescriptor> field);
void GenerateAccessorMsgImpl(Context& ctx, const FieldDescriptor& field);
void GenerateAccessorExternC(Context& ctx, const FieldDescriptor& field);
void GenerateAccessorThunkCc(Context& ctx, const FieldDescriptor& field);
} // namespace rust
} // namespace compiler

@ -15,7 +15,6 @@
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "google/protobuf/compiler/rust/context.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/strtod.h"
@ -24,33 +23,32 @@ namespace protobuf {
namespace compiler {
namespace rust {
std::string DefaultValue(Context<FieldDescriptor> field) {
switch (field.desc().type()) {
std::string DefaultValue(const FieldDescriptor& field) {
switch (field.type()) {
case FieldDescriptor::TYPE_DOUBLE:
if (std::isfinite(field.desc().default_value_double())) {
return absl::StrCat(io::SimpleDtoa(field.desc().default_value_double()),
if (std::isfinite(field.default_value_double())) {
return absl::StrCat(io::SimpleDtoa(field.default_value_double()),
"f64");
} else if (std::isnan(field.desc().default_value_double())) {
} else if (std::isnan(field.default_value_double())) {
return std::string("f64::NAN");
} else if (field.desc().default_value_double() ==
} else if (field.default_value_double() ==
std::numeric_limits<double>::infinity()) {
return std::string("f64::INFINITY");
} else if (field.desc().default_value_double() ==
} else if (field.default_value_double() ==
-std::numeric_limits<double>::infinity()) {
return std::string("f64::NEG_INFINITY");
} else {
ABSL_LOG(FATAL) << "unreachable";
}
case FieldDescriptor::TYPE_FLOAT:
if (std::isfinite(field.desc().default_value_float())) {
return absl::StrCat(io::SimpleFtoa(field.desc().default_value_float()),
"f32");
} else if (std::isnan(field.desc().default_value_float())) {
if (std::isfinite(field.default_value_float())) {
return absl::StrCat(io::SimpleFtoa(field.default_value_float()), "f32");
} else if (std::isnan(field.default_value_float())) {
return std::string("f32::NAN");
} else if (field.desc().default_value_float() ==
} else if (field.default_value_float() ==
std::numeric_limits<float>::infinity()) {
return std::string("f32::INFINITY");
} else if (field.desc().default_value_float() ==
} else if (field.default_value_float() ==
-std::numeric_limits<float>::infinity()) {
return std::string("f32::NEG_INFINITY");
} else {
@ -59,27 +57,27 @@ std::string DefaultValue(Context<FieldDescriptor> field) {
case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_SFIXED32:
case FieldDescriptor::TYPE_SINT32:
return absl::StrFormat("%d", field.desc().default_value_int32());
return absl::StrFormat("%d", field.default_value_int32());
case FieldDescriptor::TYPE_INT64:
case FieldDescriptor::TYPE_SFIXED64:
case FieldDescriptor::TYPE_SINT64:
return absl::StrFormat("%d", field.desc().default_value_int64());
return absl::StrFormat("%d", field.default_value_int64());
case FieldDescriptor::TYPE_FIXED64:
case FieldDescriptor::TYPE_UINT64:
return absl::StrFormat("%u", field.desc().default_value_uint64());
return absl::StrFormat("%u", field.default_value_uint64());
case FieldDescriptor::TYPE_FIXED32:
case FieldDescriptor::TYPE_UINT32:
return absl::StrFormat("%u", field.desc().default_value_uint32());
return absl::StrFormat("%u", field.default_value_uint32());
case FieldDescriptor::TYPE_BOOL:
return absl::StrFormat("%v", field.desc().default_value_bool());
return absl::StrFormat("%v", field.default_value_bool());
case FieldDescriptor::TYPE_STRING:
case FieldDescriptor::TYPE_BYTES:
return absl::StrFormat(
"b\"%s\"", absl::CHexEscape(field.desc().default_value_string()));
return absl::StrFormat("b\"%s\"",
absl::CHexEscape(field.default_value_string()));
case FieldDescriptor::TYPE_GROUP:
case FieldDescriptor::TYPE_MESSAGE:
case FieldDescriptor::TYPE_ENUM:
ABSL_LOG(FATAL) << "Unsupported field type: " << field.desc().type_name();
ABSL_LOG(FATAL) << "Unsupported field type: " << field.type_name();
}
ABSL_LOG(FATAL) << "unreachable";
}

@ -10,7 +10,6 @@
#include <string>
#include "google/protobuf/compiler/rust/context.h"
#include "google/protobuf/descriptor.h"
namespace google {
@ -23,7 +22,7 @@ namespace rust {
// Both strings and bytes are represented as a byte string literal, i.e. in the
// format `b"default value here"`. It is the caller's responsibility to convert
// the byte literal to an actual string, if needed.
std::string DefaultValue(Context<FieldDescriptor> field);
std::string DefaultValue(const FieldDescriptor& field);
} // namespace rust
} // namespace compiler

@ -17,19 +17,19 @@ namespace protobuf {
namespace compiler {
namespace rust {
void Map::InMsgImpl(Context<FieldDescriptor> field) const {
auto& key_type = *field.desc().message_type()->map_key();
auto& value_type = *field.desc().message_type()->map_value();
void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field) const {
auto& key_type = *field.message_type()->map_key();
auto& value_type = *field.message_type()->map_value();
field.Emit({{"field", field.desc().name()},
{"Key", PrimitiveRsTypeName(key_type)},
{"Value", PrimitiveRsTypeName(value_type)},
{"getter_thunk", Thunk(field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")},
{"getter",
[&] {
if (field.is_upb()) {
field.Emit({}, R"rs(
ctx.Emit({{"field", field.name()},
{"Key", PrimitiveRsTypeName(key_type)},
{"Value", PrimitiveRsTypeName(value_type)},
{"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"getter",
[&] {
if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
pub fn r#$field$(&self)
-> $pb$::View<'_, $pb$::Map<$Key$, $Value$>> {
let inner = unsafe {
@ -44,8 +44,8 @@ void Map::InMsgImpl(Context<FieldDescriptor> field) const {
});
$pb$::MapView::from_inner($pbi$::Private, inner)
})rs");
} else {
field.Emit({}, R"rs(
} else {
ctx.Emit({}, R"rs(
pub fn r#$field$(&self)
-> $pb$::View<'_, $pb$::Map<$Key$, $Value$>> {
let inner = $pbr$::MapInner {
@ -55,12 +55,12 @@ void Map::InMsgImpl(Context<FieldDescriptor> field) const {
};
$pb$::MapView::from_inner($pbi$::Private, inner)
})rs");
}
}},
{"getter_mut",
[&] {
if (field.is_upb()) {
field.Emit({}, R"rs(
}
}},
{"getter_mut",
[&] {
if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self)
-> $pb$::Mut<'_, $pb$::Map<$Key$, $Value$>> {
let raw = unsafe {
@ -75,8 +75,8 @@ void Map::InMsgImpl(Context<FieldDescriptor> field) const {
};
$pb$::MapMut::from_inner($pbi$::Private, inner)
})rs");
} else {
field.Emit({}, R"rs(
} else {
ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self)
-> $pb$::Mut<'_, $pb$::Map<$Key$, $Value$>> {
let inner = $pbr$::MapInner {
@ -86,30 +86,30 @@ void Map::InMsgImpl(Context<FieldDescriptor> field) const {
};
$pb$::MapMut::from_inner($pbi$::Private, inner)
})rs");
}
}}},
R"rs(
}
}}},
R"rs(
$getter$
$getter_mut$
)rs");
}
void Map::InExternC(Context<FieldDescriptor> field) const {
field.Emit(
void Map::InExternC(Context& ctx, const FieldDescriptor& field) const {
ctx.Emit(
{
{"getter_thunk", Thunk(field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")},
{"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"getter",
[&] {
if (field.is_upb()) {
field.Emit({}, R"rs(
if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
fn $getter_thunk$(raw_msg: $pbi$::RawMessage)
-> Option<$pbi$::RawMap>;
fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage,
arena: $pbi$::RawArena) -> $pbi$::RawMap;
)rs");
} else {
field.Emit({}, R"rs(
ctx.Emit({}, R"rs(
fn $getter_thunk$(msg: $pbi$::RawMessage) -> $pbi$::RawMap;
fn $getter_mut_thunk$(msg: $pbi$::RawMessage,) -> $pbi$::RawMap;
)rs");
@ -121,20 +121,19 @@ void Map::InExternC(Context<FieldDescriptor> field) const {
)rs");
}
void Map::InThunkCc(Context<FieldDescriptor> field) const {
field.Emit(
{{"field", cpp::FieldName(&field.desc())},
{"Key", cpp::PrimitiveTypeName(
field.desc().message_type()->map_key()->cpp_type())},
{"Value", cpp::PrimitiveTypeName(
field.desc().message_type()->map_value()->cpp_type())},
{"QualifiedMsg",
cpp::QualifiedClassName(field.desc().containing_type())},
{"getter_thunk", Thunk(field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")},
void Map::InThunkCc(Context& ctx, const FieldDescriptor& field) const {
ctx.Emit(
{{"field", cpp::FieldName(&field)},
{"Key",
cpp::PrimitiveTypeName(field.message_type()->map_key()->cpp_type())},
{"Value",
cpp::PrimitiveTypeName(field.message_type()->map_value()->cpp_type())},
{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"impls",
[&] {
field.Emit(
ctx.Emit(
R"cc(
const void* $getter_thunk$($QualifiedMsg$& msg) {
return &msg.$field$();

@ -17,15 +17,16 @@ namespace protobuf {
namespace compiler {
namespace rust {
void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const {
field.Emit({{"field", field.desc().name()},
{"Scalar", PrimitiveRsTypeName(field.desc())},
{"getter_thunk", Thunk(field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")},
{"getter",
[&] {
if (field.is_upb()) {
field.Emit({}, R"rs(
void RepeatedScalar::InMsgImpl(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"field", field.name()},
{"Scalar", PrimitiveRsTypeName(field)},
{"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"getter",
[&] {
if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> {
unsafe {
$getter_thunk$(
@ -40,8 +41,8 @@ void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const {
)
}
)rs");
} else {
field.Emit({}, R"rs(
} else {
ctx.Emit({}, R"rs(
pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> {
unsafe {
$pb$::RepeatedView::from_raw(
@ -51,13 +52,13 @@ void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const {
}
}
)rs");
}
}},
{"clearer_thunk", Thunk(field, "clear")},
{"field_mutator_getter",
[&] {
if (field.is_upb()) {
field.Emit({}, R"rs(
}
}},
{"clearer_thunk", Thunk(ctx, field, "clear")},
{"field_mutator_getter",
[&] {
if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> {
unsafe {
$pb$::RepeatedMut::from_inner(
@ -75,8 +76,8 @@ void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const {
}
}
)rs");
} else {
field.Emit({}, R"rs(
} else {
ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> {
unsafe {
$pb$::RepeatedMut::from_inner(
@ -89,22 +90,23 @@ void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const {
}
}
)rs");
}
}}},
R"rs(
}
}}},
R"rs(
$getter$
$field_mutator_getter$
)rs");
}
void RepeatedScalar::InExternC(Context<FieldDescriptor> field) const {
field.Emit({{"Scalar", PrimitiveRsTypeName(field.desc())},
{"getter_thunk", Thunk(field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")},
{"getter",
[&] {
if (field.is_upb()) {
field.Emit(R"rs(
void RepeatedScalar::InExternC(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"Scalar", PrimitiveRsTypeName(field)},
{"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"getter",
[&] {
if (ctx.is_upb()) {
ctx.Emit(R"rs(
fn $getter_mut_thunk$(
raw_msg: $pbi$::RawMessage,
size: *const usize,
@ -116,44 +118,44 @@ void RepeatedScalar::InExternC(Context<FieldDescriptor> field) const {
size: *const usize,
) -> Option<$pbi$::RawRepeatedField>;
)rs");
} else {
field.Emit(R"rs(
} else {
ctx.Emit(R"rs(
fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField;
fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField;
)rs");
}
}},
{"clearer_thunk", Thunk(field, "clear")}},
R"rs(
}
}},
{"clearer_thunk", Thunk(ctx, field, "clear")}},
R"rs(
fn $clearer_thunk$(raw_msg: $pbi$::RawMessage);
$getter$
)rs");
}
void RepeatedScalar::InThunkCc(Context<FieldDescriptor> field) const {
field.Emit({{"field", cpp::FieldName(&field.desc())},
{"Scalar", cpp::PrimitiveTypeName(field.desc().cpp_type())},
{"QualifiedMsg",
cpp::QualifiedClassName(field.desc().containing_type())},
{"clearer_thunk", Thunk(field, "clear")},
{"getter_thunk", Thunk(field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")},
{"impls",
[&] {
field.Emit(
R"cc(
void $clearer_thunk$($QualifiedMsg$* msg) {
msg->clear_$field$();
}
google::protobuf::RepeatedField<$Scalar$>* $getter_mut_thunk$($QualifiedMsg$* msg) {
return msg->mutable_$field$();
}
const google::protobuf::RepeatedField<$Scalar$>& $getter_thunk$($QualifiedMsg$& msg) {
return msg.$field$();
}
)cc");
}}},
"$impls$");
void RepeatedScalar::InThunkCc(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"field", cpp::FieldName(&field)},
{"Scalar", cpp::PrimitiveTypeName(field.cpp_type())},
{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"clearer_thunk", Thunk(ctx, field, "clear")},
{"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"impls",
[&] {
ctx.Emit(
R"cc(
void $clearer_thunk$($QualifiedMsg$* msg) {
msg->clear_$field$();
}
google::protobuf::RepeatedField<$Scalar$>* $getter_mut_thunk$($QualifiedMsg$* msg) {
return msg->mutable_$field$();
}
const google::protobuf::RepeatedField<$Scalar$>& $getter_thunk$($QualifiedMsg$& msg) {
return msg.$field$();
}
)cc");
}}},
"$impls$");
}
} // namespace rust

@ -17,23 +17,23 @@ namespace protobuf {
namespace compiler {
namespace rust {
void SingularMessage::InMsgImpl(Context<FieldDescriptor> field) const {
Context<Descriptor> d = field.WithDesc(field.desc().message_type());
void SingularMessage::InMsgImpl(Context& ctx,
const FieldDescriptor& field) const {
auto& msg = *field.message_type();
auto prefix = "crate::" + GetCrateRelativeQualifiedPath(ctx, msg);
auto prefix = "crate::" + GetCrateRelativeQualifiedPath(d);
field.Emit(
ctx.Emit(
{
{"prefix", prefix},
{"field", field.desc().name()},
{"getter_thunk", Thunk(field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")},
{"clearer_thunk", Thunk(field, "clear")},
{"field", field.name()},
{"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"clearer_thunk", Thunk(ctx, field, "clear")},
{
"view_body",
[&] {
if (field.is_upb()) {
field.Emit({}, R"rs(
if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
let submsg = unsafe { $getter_thunk$(self.inner.msg) };
// For upb, getters return null if the field is unset, so we need
// to check for null and return the default instance manually.
@ -46,7 +46,7 @@ void SingularMessage::InMsgImpl(Context<FieldDescriptor> field) const {
}
)rs");
} else {
field.Emit({}, R"rs(
ctx.Emit({}, R"rs(
// For C++ kernel, getters automatically return the
// default_instance if the field is unset.
let submsg = unsafe { $getter_thunk$(self.inner.msg) };
@ -57,15 +57,15 @@ void SingularMessage::InMsgImpl(Context<FieldDescriptor> field) const {
},
{"submessage_mut",
[&] {
if (field.is_upb()) {
field.Emit({}, R"rs(
if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
let submsg = unsafe {
$getter_mut_thunk$(self.inner.msg, self.inner.arena.raw())
};
$prefix$Mut::new($pbi$::Private, &mut self.inner, submsg)
)rs");
} else {
field.Emit({}, R"rs(
ctx.Emit({}, R"rs(
let submsg = unsafe { $getter_mut_thunk$(self.inner.msg) };
$prefix$Mut::new($pbi$::Private, &mut self.inner, submsg)
)rs");
@ -87,21 +87,22 @@ void SingularMessage::InMsgImpl(Context<FieldDescriptor> field) const {
)rs");
}
void SingularMessage::InExternC(Context<FieldDescriptor> field) const {
field.Emit(
void SingularMessage::InExternC(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit(
{
{"getter_thunk", Thunk(field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")},
{"clearer_thunk", Thunk(field, "clear")},
{"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"clearer_thunk", Thunk(ctx, field, "clear")},
{"getter_mut",
[&] {
if (field.is_cpp()) {
field.Emit(
if (ctx.is_cpp()) {
ctx.Emit(
R"rs(
fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage)
-> $pbi$::RawMessage;)rs");
} else {
field.Emit(
ctx.Emit(
R"rs(fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage,
arena: $pbi$::RawArena)
-> $pbi$::RawMessage;)rs");
@ -109,13 +110,13 @@ void SingularMessage::InExternC(Context<FieldDescriptor> field) const {
}},
{"ReturnType",
[&] {
if (field.is_cpp()) {
if (ctx.is_cpp()) {
// guaranteed to have a nonnull submsg for the cpp kernel
field.Emit({}, "$pbi$::RawMessage;");
ctx.Emit({}, "$pbi$::RawMessage;");
} else {
// upb kernel may return NULL for a submsg, we can detect this
// in terra rust if the option returned is None
field.Emit({}, "Option<$pbi$::RawMessage>;");
ctx.Emit({}, "Option<$pbi$::RawMessage>;");
}
}},
},
@ -126,22 +127,22 @@ void SingularMessage::InExternC(Context<FieldDescriptor> field) const {
)rs");
}
void SingularMessage::InThunkCc(Context<FieldDescriptor> field) const {
field.Emit({{"QualifiedMsg",
cpp::QualifiedClassName(field.desc().containing_type())},
{"getter_thunk", Thunk(field, "get")},
{"getter_mut_thunk", Thunk(field, "get_mut")},
{"clearer_thunk", Thunk(field, "clear")},
{"field", cpp::FieldName(&field.desc())}},
R"cc(
const void* $getter_thunk$($QualifiedMsg$* msg) {
return static_cast<const void*>(&msg->$field$());
}
void* $getter_mut_thunk$($QualifiedMsg$* msg) {
return static_cast<void*>(msg->mutable_$field$());
}
void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); }
)cc");
void SingularMessage::InThunkCc(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"getter_thunk", Thunk(ctx, field, "get")},
{"getter_mut_thunk", Thunk(ctx, field, "get_mut")},
{"clearer_thunk", Thunk(ctx, field, "clear")},
{"field", cpp::FieldName(&field)}},
R"cc(
const void* $getter_thunk$($QualifiedMsg$* msg) {
return static_cast<const void*>(&msg->$field$());
}
void* $getter_mut_thunk$($QualifiedMsg$* msg) {
return static_cast<void*>(msg->mutable_$field$());
}
void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); }
)cc");
}
} // namespace rust

@ -18,16 +18,17 @@ namespace protobuf {
namespace compiler {
namespace rust {
void SingularScalar::InMsgImpl(Context<FieldDescriptor> field) const {
field.Emit(
void SingularScalar::InMsgImpl(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit(
{
{"field", field.desc().name()},
{"Scalar", PrimitiveRsTypeName(field.desc())},
{"hazzer_thunk", Thunk(field, "has")},
{"field", field.name()},
{"Scalar", PrimitiveRsTypeName(field)},
{"hazzer_thunk", Thunk(ctx, field, "has")},
{"default_value", DefaultValue(field)},
{"getter",
[&] {
field.Emit({}, R"rs(
ctx.Emit({}, R"rs(
pub fn r#$field$(&self) -> $Scalar$ {
unsafe { $getter_thunk$(self.inner.msg) }
}
@ -35,9 +36,9 @@ void SingularScalar::InMsgImpl(Context<FieldDescriptor> field) const {
}},
{"getter_opt",
[&] {
if (!field.desc().is_optional()) return;
if (!field.desc().has_presence()) return;
field.Emit({}, R"rs(
if (!field.is_optional()) return;
if (!field.has_presence()) return;
ctx.Emit({}, R"rs(
pub fn r#$field$_opt(&self) -> $pb$::Optional<$Scalar$> {
if !unsafe { $hazzer_thunk$(self.inner.msg) } {
return $pb$::Optional::Unset($default_value$);
@ -47,13 +48,13 @@ void SingularScalar::InMsgImpl(Context<FieldDescriptor> field) const {
}
)rs");
}},
{"getter_thunk", Thunk(field, "get")},
{"setter_thunk", Thunk(field, "set")},
{"clearer_thunk", Thunk(field, "clear")},
{"getter_thunk", Thunk(ctx, field, "get")},
{"setter_thunk", Thunk(ctx, field, "set")},
{"clearer_thunk", Thunk(ctx, field, "clear")},
{"field_mutator_getter",
[&] {
if (field.desc().has_presence()) {
field.Emit({}, R"rs(
if (field.has_presence()) {
ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self) -> $pb$::FieldEntry<'_, $Scalar$> {
static VTABLE: $pbi$::PrimitiveOptionalMutVTable<$Scalar$> =
$pbi$::PrimitiveOptionalMutVTable::new(
@ -76,7 +77,7 @@ void SingularScalar::InMsgImpl(Context<FieldDescriptor> field) const {
}
)rs");
} else {
field.Emit({}, R"rs(
ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&mut self) -> $pb$::Mut<'_, $Scalar$> {
static VTABLE: $pbi$::PrimitiveVTable<$Scalar$> =
$pbi$::PrimitiveVTable::new(
@ -114,56 +115,57 @@ void SingularScalar::InMsgImpl(Context<FieldDescriptor> field) const {
)rs");
}
void SingularScalar::InExternC(Context<FieldDescriptor> field) const {
field.Emit({{"Scalar", PrimitiveRsTypeName(field.desc())},
{"hazzer_thunk", Thunk(field, "has")},
{"getter_thunk", Thunk(field, "get")},
{"setter_thunk", Thunk(field, "set")},
{"clearer_thunk", Thunk(field, "clear")},
{"hazzer_and_clearer",
[&] {
if (field.desc().has_presence()) {
field.Emit(
R"rs(
void SingularScalar::InExternC(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"Scalar", PrimitiveRsTypeName(field)},
{"hazzer_thunk", Thunk(ctx, field, "has")},
{"getter_thunk", Thunk(ctx, field, "get")},
{"setter_thunk", Thunk(ctx, field, "set")},
{"clearer_thunk", Thunk(ctx, field, "clear")},
{"hazzer_and_clearer",
[&] {
if (field.has_presence()) {
ctx.Emit(
R"rs(
fn $hazzer_thunk$(raw_msg: $pbi$::RawMessage) -> bool;
fn $clearer_thunk$(raw_msg: $pbi$::RawMessage);
)rs");
}
}}},
R"rs(
}
}}},
R"rs(
$hazzer_and_clearer$
fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $Scalar$;
fn $setter_thunk$(raw_msg: $pbi$::RawMessage, val: $Scalar$);
)rs");
}
void SingularScalar::InThunkCc(Context<FieldDescriptor> field) const {
field.Emit({{"field", cpp::FieldName(&field.desc())},
{"Scalar", cpp::PrimitiveTypeName(field.desc().cpp_type())},
{"QualifiedMsg",
cpp::QualifiedClassName(field.desc().containing_type())},
{"hazzer_thunk", Thunk(field, "has")},
{"getter_thunk", Thunk(field, "get")},
{"setter_thunk", Thunk(field, "set")},
{"clearer_thunk", Thunk(field, "clear")},
{"hazzer_and_clearer",
[&] {
if (field.desc().has_presence()) {
field.Emit(R"cc(
bool $hazzer_thunk$($QualifiedMsg$* msg) {
return msg->has_$field$();
}
void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); }
)cc");
}
}}},
R"cc(
$hazzer_and_clearer$;
$Scalar$ $getter_thunk$($QualifiedMsg$* msg) { return msg->$field$(); }
void $setter_thunk$($QualifiedMsg$* msg, $Scalar$ val) {
msg->set_$field$(val);
void SingularScalar::InThunkCc(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"field", cpp::FieldName(&field)},
{"Scalar", cpp::PrimitiveTypeName(field.cpp_type())},
{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"hazzer_thunk", Thunk(ctx, field, "has")},
{"getter_thunk", Thunk(ctx, field, "get")},
{"setter_thunk", Thunk(ctx, field, "set")},
{"clearer_thunk", Thunk(ctx, field, "clear")},
{"hazzer_and_clearer",
[&] {
if (field.has_presence()) {
ctx.Emit(R"cc(
bool $hazzer_thunk$($QualifiedMsg$* msg) {
return msg->has_$field$();
}
void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); }
)cc");
}
)cc");
}}},
R"cc(
$hazzer_and_clearer$;
$Scalar$ $getter_thunk$($QualifiedMsg$* msg) { return msg->$field$(); }
void $setter_thunk$($QualifiedMsg$* msg, $Scalar$ val) {
msg->set_$field$(val);
}
)cc");
}
} // namespace rust

@ -20,24 +20,25 @@ namespace protobuf {
namespace compiler {
namespace rust {
void SingularString::InMsgImpl(Context<FieldDescriptor> field) const {
std::string hazzer_thunk = Thunk(field, "has");
std::string getter_thunk = Thunk(field, "get");
std::string setter_thunk = Thunk(field, "set");
std::string proxied_type = PrimitiveRsTypeName(field.desc());
void SingularString::InMsgImpl(Context& ctx,
const FieldDescriptor& field) const {
std::string hazzer_thunk = Thunk(ctx, field, "has");
std::string getter_thunk = Thunk(ctx, field, "get");
std::string setter_thunk = Thunk(ctx, field, "set");
std::string proxied_type = PrimitiveRsTypeName(field);
auto transform_view = [&] {
if (field.desc().type() == FieldDescriptor::TYPE_STRING) {
field.Emit(R"rs(
if (field.type() == FieldDescriptor::TYPE_STRING) {
ctx.Emit(R"rs(
// SAFETY: The runtime doesn't require ProtoStr to be UTF-8.
unsafe { $pb$::ProtoStr::from_utf8_unchecked(view) }
)rs");
} else {
field.Emit("view");
ctx.Emit("view");
}
};
field.Emit(
ctx.Emit(
{
{"field", field.desc().name()},
{"field", field.name()},
{"hazzer_thunk", hazzer_thunk},
{"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk},
@ -45,12 +46,12 @@ void SingularString::InMsgImpl(Context<FieldDescriptor> field) const {
{"transform_view", transform_view},
{"field_optional_getter",
[&] {
if (!field.desc().is_optional()) return;
if (!field.desc().has_presence()) return;
field.Emit({{"hazzer_thunk", hazzer_thunk},
{"getter_thunk", getter_thunk},
{"transform_view", transform_view}},
R"rs(
if (!field.is_optional()) return;
if (!field.has_presence()) return;
ctx.Emit({{"hazzer_thunk", hazzer_thunk},
{"getter_thunk", getter_thunk},
{"transform_view", transform_view}},
R"rs(
pub fn $field$_opt(&self) -> $pb$::Optional<&$proxied_type$> {
let view = unsafe { $getter_thunk$(self.inner.msg).as_ref() };
$pb$::Optional::new(
@ -62,30 +63,29 @@ void SingularString::InMsgImpl(Context<FieldDescriptor> field) const {
}},
{"field_mutator_getter",
[&] {
if (field.desc().has_presence()) {
field.Emit(
if (field.has_presence()) {
ctx.Emit(
{
{"field", field.desc().name()},
{"field", field.name()},
{"proxied_type", proxied_type},
{"default_val", DefaultValue(field)},
{"view_type", proxied_type},
{"transform_field_entry",
[&] {
if (field.desc().type() ==
FieldDescriptor::TYPE_STRING) {
field.Emit(R"rs(
if (field.type() == FieldDescriptor::TYPE_STRING) {
ctx.Emit(R"rs(
$pb$::ProtoStrMut::field_entry_from_bytes(
$pbi$::Private, out
)
)rs");
} else {
field.Emit("out");
ctx.Emit("out");
}
}},
{"hazzer_thunk", hazzer_thunk},
{"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk},
{"clearer_thunk", Thunk(field, "clear")},
{"clearer_thunk", Thunk(ctx, field, "clear")},
},
R"rs(
pub fn $field$_mut(&mut self) -> $pb$::FieldEntry<'_, $proxied_type$> {
@ -112,11 +112,11 @@ void SingularString::InMsgImpl(Context<FieldDescriptor> field) const {
}
)rs");
} else {
field.Emit({{"field", field.desc().name()},
{"proxied_type", proxied_type},
{"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk}},
R"rs(
ctx.Emit({{"field", field.name()},
{"proxied_type", proxied_type},
{"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk}},
R"rs(
pub fn $field$_mut(&mut self) -> $pb$::Mut<'_, $proxied_type$> {
static VTABLE: $pbi$::BytesMutVTable = unsafe {
$pbi$::BytesMutVTable::new(
@ -152,20 +152,21 @@ void SingularString::InMsgImpl(Context<FieldDescriptor> field) const {
)rs");
}
void SingularString::InExternC(Context<FieldDescriptor> field) const {
field.Emit({{"hazzer_thunk", Thunk(field, "has")},
{"getter_thunk", Thunk(field, "get")},
{"setter_thunk", Thunk(field, "set")},
{"clearer_thunk", Thunk(field, "clear")},
{"hazzer",
[&] {
if (field.desc().has_presence()) {
field.Emit(R"rs(
void SingularString::InExternC(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"hazzer_thunk", Thunk(ctx, field, "has")},
{"getter_thunk", Thunk(ctx, field, "get")},
{"setter_thunk", Thunk(ctx, field, "set")},
{"clearer_thunk", Thunk(ctx, field, "clear")},
{"hazzer",
[&] {
if (field.has_presence()) {
ctx.Emit(R"rs(
fn $hazzer_thunk$(raw_msg: $pbi$::RawMessage) -> bool;
)rs");
}
}}},
R"rs(
}
}}},
R"rs(
$hazzer$
fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::PtrAndLen;
fn $setter_thunk$(raw_msg: $pbi$::RawMessage, val: $pbi$::PtrAndLen);
@ -173,35 +174,35 @@ void SingularString::InExternC(Context<FieldDescriptor> field) const {
)rs");
}
void SingularString::InThunkCc(Context<FieldDescriptor> field) const {
field.Emit({{"field", cpp::FieldName(&field.desc())},
{"QualifiedMsg",
cpp::QualifiedClassName(field.desc().containing_type())},
{"hazzer_thunk", Thunk(field, "has")},
{"getter_thunk", Thunk(field, "get")},
{"setter_thunk", Thunk(field, "set")},
{"clearer_thunk", Thunk(field, "clear")},
{"hazzer",
[&] {
if (field.desc().has_presence()) {
field.Emit(R"cc(
bool $hazzer_thunk$($QualifiedMsg$* msg) {
return msg->has_$field$();
}
void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); }
)cc");
}
}}},
R"cc(
$hazzer$;
::google::protobuf::rust_internal::PtrAndLen $getter_thunk$($QualifiedMsg$* msg) {
absl::string_view val = msg->$field$();
return ::google::protobuf::rust_internal::PtrAndLen(val.data(), val.size());
}
void $setter_thunk$($QualifiedMsg$* msg, ::google::protobuf::rust_internal::PtrAndLen s) {
msg->set_$field$(absl::string_view(s.ptr, s.len));
void SingularString::InThunkCc(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"field", cpp::FieldName(&field)},
{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"hazzer_thunk", Thunk(ctx, field, "has")},
{"getter_thunk", Thunk(ctx, field, "get")},
{"setter_thunk", Thunk(ctx, field, "set")},
{"clearer_thunk", Thunk(ctx, field, "clear")},
{"hazzer",
[&] {
if (field.has_presence()) {
ctx.Emit(R"cc(
bool $hazzer_thunk$($QualifiedMsg$* msg) {
return msg->has_$field$();
}
void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); }
)cc");
}
)cc");
}}},
R"cc(
$hazzer$;
::google::protobuf::rust_internal::PtrAndLen $getter_thunk$($QualifiedMsg$* msg) {
absl::string_view val = msg->$field$();
return ::google::protobuf::rust_internal::PtrAndLen(val.data(), val.size());
}
void $setter_thunk$($QualifiedMsg$* msg, ::google::protobuf::rust_internal::PtrAndLen s) {
msg->set_$field$(absl::string_view(s.ptr, s.len));
}
)cc");
}
} // namespace rust

@ -15,11 +15,12 @@ namespace protobuf {
namespace compiler {
namespace rust {
void UnsupportedField::InMsgImpl(Context<FieldDescriptor> field) const {
field.Emit({{"reason", reason_}}, R"rs(
void UnsupportedField::InMsgImpl(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"reason", reason_}}, R"rs(
// Unsupported! :( Reason: $reason$
)rs");
field.printer().PrintRaw("\n");
ctx.printer().PrintRaw("\n");
}
} // namespace rust

@ -68,13 +68,12 @@ absl::StatusOr<Options> Options::Parse(absl::string_view param) {
return opts;
}
bool IsInCurrentlyGeneratingCrate(Context<FileDescriptor> file) {
return file.generator_context().is_file_in_current_crate(&file.desc());
bool IsInCurrentlyGeneratingCrate(Context& ctx, const FileDescriptor& file) {
return ctx.generator_context().is_file_in_current_crate(file);
}
bool IsInCurrentlyGeneratingCrate(Context<Descriptor> message) {
return message.generator_context().is_file_in_current_crate(
message.desc().file());
bool IsInCurrentlyGeneratingCrate(Context& ctx, const Descriptor& message) {
return IsInCurrentlyGeneratingCrate(ctx, *message.file());
}
} // namespace rust

@ -53,14 +53,14 @@ class RustGeneratorContext {
const std::vector<const FileDescriptor*>* files_in_current_crate)
: files_in_current_crate_(*files_in_current_crate) {}
const FileDescriptor* primary_file() const {
return files_in_current_crate_.front();
const FileDescriptor& primary_file() const {
return *files_in_current_crate_.front();
}
bool is_file_in_current_crate(const FileDescriptor* f) const {
bool is_file_in_current_crate(const FileDescriptor& f) const {
return std::find(files_in_current_crate_.begin(),
files_in_current_crate_.end(),
f) != files_in_current_crate_.end();
&f) != files_in_current_crate_.end();
}
private:
@ -68,26 +68,20 @@ class RustGeneratorContext {
};
// A context for generating a particular kind of definition.
// This type acts as an options struct (as in go/totw/173) for most of the
// generator.
//
// `Descriptor` is the type of a descriptor.h class relevant for the current
// context.
template <typename Descriptor>
class Context {
public:
Context(const Options* opts, const Descriptor* desc,
Context(const Options* opts,
const RustGeneratorContext* rust_generator_context,
io::Printer* printer)
: opts_(opts),
desc_(desc),
rust_generator_context_(rust_generator_context),
printer_(printer) {}
Context(const Context&) = default;
Context& operator=(const Context&) = default;
Context(const Context&) = delete;
Context& operator=(const Context&) = delete;
Context(Context&&) = default;
Context& operator=(Context&&) = default;
const Descriptor& desc() const { return *desc_; }
const Options& opts() const { return *opts_; }
const RustGeneratorContext& generator_context() const {
return *rust_generator_context_;
@ -99,19 +93,8 @@ class Context {
// NOTE: prefer ctx.Emit() over ctx.printer().Emit();
io::Printer& printer() const { return *printer_; }
// Creates a new context over a different descriptor.
template <typename D>
Context<D> WithDesc(const D& desc) const {
return Context<D>(opts_, &desc, rust_generator_context_, printer_);
}
template <typename D>
Context<D> WithDesc(const D* desc) const {
return Context<D>(opts_, desc, rust_generator_context_, printer_);
}
Context WithPrinter(io::Printer* printer) const {
return Context(opts_, desc_, rust_generator_context_, printer);
return Context(opts_, rust_generator_context_, printer);
}
// Forwards to Emit(), which will likely be called all the time.
@ -128,13 +111,12 @@ class Context {
private:
const Options* opts_;
const Descriptor* desc_;
const RustGeneratorContext* rust_generator_context_;
io::Printer* printer_;
};
bool IsInCurrentlyGeneratingCrate(Context<FileDescriptor> file);
bool IsInCurrentlyGeneratingCrate(Context<Descriptor> message);
bool IsInCurrentlyGeneratingCrate(Context& ctx, const FileDescriptor& file);
bool IsInCurrentlyGeneratingCrate(Context& ctx, const Descriptor& message);
} // namespace rust
} // namespace compiler

@ -48,12 +48,11 @@ namespace {
// pub mod submodule {
// pub mod separator {
// ```
void EmitOpeningOfPackageModules(absl::string_view pkg,
Context<FileDescriptor> file) {
void EmitOpeningOfPackageModules(Context& ctx, absl::string_view pkg) {
if (pkg.empty()) return;
for (absl::string_view segment : absl::StrSplit(pkg, '.')) {
file.Emit({{"segment", segment}},
R"rs(
ctx.Emit({{"segment", segment}},
R"rs(
pub mod $segment$ {
)rs");
}
@ -70,14 +69,13 @@ void EmitOpeningOfPackageModules(absl::string_view pkg,
// } // mod uses
// } // mod package
// ```
void EmitClosingOfPackageModules(absl::string_view pkg,
Context<FileDescriptor> file) {
void EmitClosingOfPackageModules(Context& ctx, absl::string_view pkg) {
if (pkg.empty()) return;
std::vector<absl::string_view> segments = absl::StrSplit(pkg, '.');
absl::c_reverse(segments);
for (absl::string_view segment : segments) {
file.Emit({{"segment", segment}}, R"rs(
ctx.Emit({{"segment", segment}}, R"rs(
} // mod $segment$
)rs");
}
@ -87,14 +85,13 @@ void EmitClosingOfPackageModules(absl::string_view pkg,
// `non_primary_src` into the `primary_file`.
//
// `non_primary_src` has to be a non-primary src of the current `proto_library`.
void EmitPubUseOfOwnMessages(Context<FileDescriptor>& primary_file,
const Context<FileDescriptor>& non_primary_src) {
for (int i = 0; i < non_primary_src.desc().message_type_count(); ++i) {
auto msg = primary_file.WithDesc(non_primary_src.desc().message_type(i));
auto mod = RustInternalModuleName(non_primary_src);
auto name = msg.desc().name();
primary_file.Emit({{"mod", mod}, {"Msg", name}},
R"rs(
void EmitPubUseOfOwnMessages(Context& ctx, const FileDescriptor& primary_file,
const FileDescriptor& non_primary_src) {
for (int i = 0; i < non_primary_src.message_type_count(); ++i) {
auto& msg = *non_primary_src.message_type(i);
auto mod = RustInternalModuleName(ctx, non_primary_src);
ctx.Emit({{"mod", mod}, {"Msg", msg.name()}},
R"rs(
pub use crate::$mod$::$Msg$;
// TODO Address use for imported crates
pub use crate::$mod$::$Msg$View;
@ -109,14 +106,15 @@ void EmitPubUseOfOwnMessages(Context<FileDescriptor>& primary_file,
//
// `dep` is a primary src of a dependency of the current `proto_library`.
// TODO: Add support for public import of non-primary srcs of deps.
void EmitPubUseForImportedMessages(Context<FileDescriptor>& primary_file,
const Context<FileDescriptor>& dep) {
std::string crate_name = GetCrateName(dep);
for (int i = 0; i < dep.desc().message_type_count(); ++i) {
auto msg = primary_file.WithDesc(dep.desc().message_type(i));
auto path = GetCrateRelativeQualifiedPath(msg);
primary_file.Emit({{"crate", crate_name}, {"pkg::Msg", path}},
R"rs(
void EmitPubUseForImportedMessages(Context& ctx,
const FileDescriptor& primary_file,
const FileDescriptor& dep) {
std::string crate_name = GetCrateName(ctx, dep);
for (int i = 0; i < dep.message_type_count(); ++i) {
auto& msg = *dep.message_type(i);
auto path = GetCrateRelativeQualifiedPath(ctx, msg);
ctx.Emit({{"crate", crate_name}, {"pkg::Msg", path}},
R"rs(
pub use $crate$::$pkg::Msg$;
pub use $crate$::$pkg::Msg$View;
)rs");
@ -124,9 +122,9 @@ void EmitPubUseForImportedMessages(Context<FileDescriptor>& primary_file,
}
// Emits all public imports of the current file
void EmitPublicImports(Context<FileDescriptor>& primary_file) {
for (int i = 0; i < primary_file.desc().public_dependency_count(); ++i) {
auto dep_file = primary_file.desc().public_dependency(i);
void EmitPublicImports(Context& ctx, const FileDescriptor& primary_file) {
for (int i = 0; i < primary_file.public_dependency_count(); ++i) {
auto& dep_file = *primary_file.public_dependency(i);
// If the publicly imported file is a src of the current `proto_library`
// we don't need to emit `pub use` here, we already do it for all srcs in
// RustGenerator::Generate. In other words, all srcs are implicitly publicly
@ -134,30 +132,29 @@ void EmitPublicImports(Context<FileDescriptor>& primary_file) {
// TODO: Handle the case where a non-primary src with the same
// declared package as the primary src publicly imports a file that the
// primary doesn't.
auto dep = primary_file.WithDesc(dep_file);
if (IsInCurrentlyGeneratingCrate(dep)) {
if (IsInCurrentlyGeneratingCrate(ctx, dep_file)) {
return;
}
EmitPubUseForImportedMessages(primary_file, dep);
EmitPubUseForImportedMessages(ctx, primary_file, dep_file);
}
}
// Emits submodule declarations so `rustc` can find non primary sources from the
// primary file.
void DeclareSubmodulesForNonPrimarySrcs(
Context<FileDescriptor>& primary_file,
absl::Span<const Context<FileDescriptor>> non_primary_srcs) {
std::string primary_file_path = GetRsFile(primary_file);
Context& ctx, const FileDescriptor& primary_file,
absl::Span<const FileDescriptor* const> non_primary_srcs) {
std::string primary_file_path = GetRsFile(ctx, primary_file);
RelativePath primary_relpath(primary_file_path);
for (const auto& non_primary_src : non_primary_srcs) {
std::string non_primary_file_path = GetRsFile(non_primary_src);
for (const FileDescriptor* non_primary_src : non_primary_srcs) {
std::string non_primary_file_path = GetRsFile(ctx, *non_primary_src);
std::string relative_mod_path =
primary_relpath.Relative(RelativePath(non_primary_file_path));
primary_file.Emit({{"file_path", relative_mod_path},
{"foo", primary_file_path},
{"bar", non_primary_file_path},
{"mod_name", RustInternalModuleName(non_primary_src)}},
R"rs(
ctx.Emit({{"file_path", relative_mod_path},
{"foo", primary_file_path},
{"bar", non_primary_file_path},
{"mod_name", RustInternalModuleName(ctx, *non_primary_src)}},
R"rs(
#[path="$file_path$"]
pub mod $mod_name$;
)rs");
@ -169,33 +166,32 @@ void DeclareSubmodulesForNonPrimarySrcs(
//
// Returns the non-primary sources that should be reexported from the package of
// the primary file.
std::vector<const Context<FileDescriptor>*> ReexportMessagesFromSubmodules(
Context<FileDescriptor>& primary_file,
absl::Span<const Context<FileDescriptor>> non_primary_srcs) {
absl::btree_map<absl::string_view,
std::vector<const Context<FileDescriptor>*>>
std::vector<const FileDescriptor*> ReexportMessagesFromSubmodules(
Context& ctx, const FileDescriptor& primary_file,
absl::Span<const FileDescriptor* const> non_primary_srcs) {
absl::btree_map<absl::string_view, std::vector<const FileDescriptor*>>
packages;
for (const Context<FileDescriptor>& ctx : non_primary_srcs) {
packages[ctx.desc().package()].push_back(&ctx);
for (const FileDescriptor* file : non_primary_srcs) {
packages[file->package()].push_back(file);
}
for (const auto& pair : packages) {
// We will deal with messages for the package of the primary file later.
auto fds = pair.second;
absl::string_view package = fds[0]->desc().package();
if (package == primary_file.desc().package()) continue;
absl::string_view package = fds[0]->package();
if (package == primary_file.package()) continue;
EmitOpeningOfPackageModules(package, primary_file);
for (const Context<FileDescriptor>* c : fds) {
EmitPubUseOfOwnMessages(primary_file, *c);
EmitOpeningOfPackageModules(ctx, package);
for (const FileDescriptor* c : fds) {
EmitPubUseOfOwnMessages(ctx, primary_file, *c);
}
EmitClosingOfPackageModules(package, primary_file);
EmitClosingOfPackageModules(ctx, package);
}
return packages[primary_file.desc().package()];
return packages[primary_file.package()];
}
} // namespace
bool RustGenerator::Generate(const FileDescriptor* file_desc,
bool RustGenerator::Generate(const FileDescriptor* file,
const std::string& parameter,
GeneratorContext* generator_context,
std::string* error) const {
@ -210,15 +206,15 @@ bool RustGenerator::Generate(const FileDescriptor* file_desc,
RustGeneratorContext rust_generator_context(&files_in_current_crate);
Context<FileDescriptor> file(&*opts, file_desc, &rust_generator_context,
nullptr);
Context ctx_without_printer(&*opts, &rust_generator_context, nullptr);
auto outfile = absl::WrapUnique(generator_context->Open(GetRsFile(file)));
auto outfile = absl::WrapUnique(
generator_context->Open(GetRsFile(ctx_without_printer, *file)));
io::Printer printer(outfile.get());
file = file.WithPrinter(&printer);
Context ctx = ctx_without_printer.WithPrinter(&printer);
// Convenience shorthands for common symbols.
auto v = file.printer().WithVars({
auto v = ctx.printer().WithVars({
{"std", "::__std"},
{"pb", "::__pb"},
{"pbi", "::__pb::__internal"},
@ -227,67 +223,66 @@ bool RustGenerator::Generate(const FileDescriptor* file_desc,
{"Phantom", "::__std::marker::PhantomData"},
});
file.Emit({{"kernel", KernelRsName(file.opts().kernel)}}, R"rs(
ctx.Emit({{"kernel", KernelRsName(ctx.opts().kernel)}}, R"rs(
extern crate protobuf_$kernel$ as __pb;
extern crate std as __std;
)rs");
std::vector<Context<FileDescriptor>> file_contexts;
std::vector<const FileDescriptor*> file_contexts;
for (const FileDescriptor* f : files_in_current_crate) {
file_contexts.push_back(file.WithDesc(*f));
file_contexts.push_back(f);
}
// Generating the primary file?
if (file_desc == rust_generator_context.primary_file()) {
if (file == &rust_generator_context.primary_file()) {
auto non_primary_srcs = absl::MakeConstSpan(file_contexts).subspan(1);
DeclareSubmodulesForNonPrimarySrcs(file, non_primary_srcs);
DeclareSubmodulesForNonPrimarySrcs(ctx, *file, non_primary_srcs);
std::vector<const Context<FileDescriptor>*>
non_primary_srcs_in_primary_package =
ReexportMessagesFromSubmodules(file, non_primary_srcs);
std::vector<const FileDescriptor*> non_primary_srcs_in_primary_package =
ReexportMessagesFromSubmodules(ctx, *file, non_primary_srcs);
EmitOpeningOfPackageModules(file.desc().package(), file);
EmitOpeningOfPackageModules(ctx, file->package());
for (const Context<FileDescriptor>* non_primary_file :
for (const FileDescriptor* non_primary_file :
non_primary_srcs_in_primary_package) {
EmitPubUseOfOwnMessages(file, *non_primary_file);
EmitPubUseOfOwnMessages(ctx, *file, *non_primary_file);
}
}
EmitPublicImports(file);
EmitPublicImports(ctx, *file);
std::unique_ptr<io::ZeroCopyOutputStream> thunks_cc;
std::unique_ptr<io::Printer> thunks_printer;
if (file.is_cpp()) {
thunks_cc.reset(generator_context->Open(GetThunkCcFile(file)));
if (ctx.is_cpp()) {
thunks_cc.reset(generator_context->Open(GetThunkCcFile(ctx, *file)));
thunks_printer = std::make_unique<io::Printer>(thunks_cc.get());
thunks_printer->Emit({{"proto_h", GetHeaderFile(file)}},
thunks_printer->Emit({{"proto_h", GetHeaderFile(ctx, *file)}},
R"cc(
#include "$proto_h$"
#include "google/protobuf/rust/cpp_kernel/cpp_api.h"
)cc");
}
for (int i = 0; i < file.desc().message_type_count(); ++i) {
auto msg = file.WithDesc(file.desc().message_type(i));
for (int i = 0; i < file->message_type_count(); ++i) {
auto& msg = *file->message_type(i);
GenerateRs(msg);
msg.printer().PrintRaw("\n");
GenerateRs(ctx, msg);
ctx.printer().PrintRaw("\n");
if (file.is_cpp()) {
auto thunks_msg = msg.WithPrinter(thunks_printer.get());
if (ctx.is_cpp()) {
auto thunks_ctx = ctx.WithPrinter(thunks_printer.get());
thunks_msg.Emit({{"Msg", msg.desc().full_name()}}, R"cc(
thunks_ctx.Emit({{"Msg", msg.full_name()}}, R"cc(
// $Msg$
)cc");
GenerateThunksCc(thunks_msg);
thunks_msg.printer().PrintRaw("\n");
GenerateThunksCc(thunks_ctx, msg);
thunks_ctx.printer().PrintRaw("\n");
}
}
if (file_desc == files_in_current_crate.front()) {
EmitClosingOfPackageModules(file.desc().package(), file);
if (file == files_in_current_crate.front()) {
EmitClosingOfPackageModules(ctx, file->package());
}
return true;
}

@ -24,16 +24,16 @@ namespace compiler {
namespace rust {
namespace {
void MessageNew(Context<Descriptor> msg) {
switch (msg.opts().kernel) {
void MessageNew(Context& ctx, const Descriptor& msg) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
msg.Emit({{"new_thunk", Thunk(msg, "new")}}, R"rs(
ctx.Emit({{"new_thunk", Thunk(ctx, msg, "new")}}, R"rs(
Self { inner: $pbr$::MessageInner { msg: unsafe { $new_thunk$() } } }
)rs");
return;
case Kernel::kUpb:
msg.Emit({{"new_thunk", Thunk(msg, "new")}}, R"rs(
ctx.Emit({{"new_thunk", Thunk(ctx, msg, "new")}}, R"rs(
let arena = $pbr$::Arena::new();
Self {
inner: $pbr$::MessageInner {
@ -48,16 +48,16 @@ void MessageNew(Context<Descriptor> msg) {
ABSL_LOG(FATAL) << "unreachable";
}
void MessageSerialize(Context<Descriptor> msg) {
switch (msg.opts().kernel) {
void MessageSerialize(Context& ctx, const Descriptor& msg) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
msg.Emit({{"serialize_thunk", Thunk(msg, "serialize")}}, R"rs(
ctx.Emit({{"serialize_thunk", Thunk(ctx, msg, "serialize")}}, R"rs(
unsafe { $serialize_thunk$(self.inner.msg) }
)rs");
return;
case Kernel::kUpb:
msg.Emit({{"serialize_thunk", Thunk(msg, "serialize")}}, R"rs(
ctx.Emit({{"serialize_thunk", Thunk(ctx, msg, "serialize")}}, R"rs(
let arena = $pbr$::Arena::new();
let mut len = 0;
unsafe {
@ -71,12 +71,12 @@ void MessageSerialize(Context<Descriptor> msg) {
ABSL_LOG(FATAL) << "unreachable";
}
void MessageDeserialize(Context<Descriptor> msg) {
switch (msg.opts().kernel) {
void MessageDeserialize(Context& ctx, const Descriptor& msg) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
msg.Emit(
ctx.Emit(
{
{"deserialize_thunk", Thunk(msg, "deserialize")},
{"deserialize_thunk", Thunk(ctx, msg, "deserialize")},
},
R"rs(
let success = unsafe {
@ -92,7 +92,7 @@ void MessageDeserialize(Context<Descriptor> msg) {
return;
case Kernel::kUpb:
msg.Emit({{"deserialize_thunk", Thunk(msg, "parse")}}, R"rs(
ctx.Emit({{"deserialize_thunk", Thunk(ctx, msg, "parse")}}, R"rs(
let arena = $pbr$::Arena::new();
let msg = unsafe {
$deserialize_thunk$(data.as_ptr(), data.len(), arena.raw())
@ -115,15 +115,15 @@ void MessageDeserialize(Context<Descriptor> msg) {
ABSL_LOG(FATAL) << "unreachable";
}
void MessageExterns(Context<Descriptor> msg) {
switch (msg.opts().kernel) {
void MessageExterns(Context& ctx, const Descriptor& msg) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
msg.Emit(
ctx.Emit(
{
{"new_thunk", Thunk(msg, "new")},
{"delete_thunk", Thunk(msg, "delete")},
{"serialize_thunk", Thunk(msg, "serialize")},
{"deserialize_thunk", Thunk(msg, "deserialize")},
{"new_thunk", Thunk(ctx, msg, "new")},
{"delete_thunk", Thunk(ctx, msg, "delete")},
{"serialize_thunk", Thunk(ctx, msg, "serialize")},
{"deserialize_thunk", Thunk(ctx, msg, "deserialize")},
},
R"rs(
fn $new_thunk$() -> $pbi$::RawMessage;
@ -134,11 +134,11 @@ void MessageExterns(Context<Descriptor> msg) {
return;
case Kernel::kUpb:
msg.Emit(
ctx.Emit(
{
{"new_thunk", Thunk(msg, "new")},
{"serialize_thunk", Thunk(msg, "serialize")},
{"deserialize_thunk", Thunk(msg, "parse")},
{"new_thunk", Thunk(ctx, msg, "new")},
{"serialize_thunk", Thunk(ctx, msg, "serialize")},
{"deserialize_thunk", Thunk(ctx, msg, "parse")},
},
R"rs(
fn $new_thunk$(arena: $pbi$::RawArena) -> $pbi$::RawMessage;
@ -151,36 +151,37 @@ void MessageExterns(Context<Descriptor> msg) {
ABSL_LOG(FATAL) << "unreachable";
}
void MessageDrop(Context<Descriptor> msg) {
if (msg.is_upb()) {
void MessageDrop(Context& ctx, const Descriptor& msg) {
if (ctx.is_upb()) {
// Nothing to do here; drop glue (which will run drop(self.arena)
// automatically) is sufficient.
return;
}
msg.Emit({{"delete_thunk", Thunk(msg, "delete")}}, R"rs(
ctx.Emit({{"delete_thunk", Thunk(ctx, msg, "delete")}}, R"rs(
unsafe { $delete_thunk$(self.inner.msg); }
)rs");
}
void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
auto fieldName = field.desc().name();
auto fieldType = field.desc().type();
auto getter_thunk = Thunk(field, "get");
auto setter_thunk = Thunk(field, "set");
auto clearer_thunk = Thunk(field, "clear");
void GetterForViewOrMut(Context& ctx, const FieldDescriptor& field,
bool is_mut) {
auto fieldName = field.name();
auto fieldType = field.type();
auto getter_thunk = Thunk(ctx, field, "get");
auto setter_thunk = Thunk(ctx, field, "set");
auto clearer_thunk = Thunk(ctx, field, "clear");
// If we're dealing with a Mut, the getter must be supplied
// self.inner.msg() whereas a View has to be supplied self.msg
auto self = is_mut ? "self.inner.msg()" : "self.msg";
if (fieldType == FieldDescriptor::TYPE_MESSAGE) {
Context<Descriptor> d = field.WithDesc(field.desc().message_type());
const Descriptor& msg = *field.message_type();
// TODO: support messages which are defined in other crates.
if (!IsInCurrentlyGeneratingCrate(d)) {
if (!IsInCurrentlyGeneratingCrate(ctx, msg)) {
return;
}
auto prefix = "crate::" + GetCrateRelativeQualifiedPath(d);
field.Emit(
auto prefix = "crate::" + GetCrateRelativeQualifiedPath(ctx, msg);
ctx.Emit(
{
{"prefix", prefix},
{"field", fieldName},
@ -190,8 +191,8 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
{
"view_body",
[&] {
if (field.is_upb()) {
field.Emit({}, R"rs(
if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
let submsg = unsafe { $getter_thunk$($self$) };
match submsg {
None => $prefix$View::new($pbi$::Private,
@ -200,7 +201,7 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
}
)rs");
} else {
field.Emit({}, R"rs(
ctx.Emit({}, R"rs(
let submsg = unsafe { $getter_thunk$($self$) };
$prefix$View::new($pbi$::Private, submsg)
)rs");
@ -216,18 +217,18 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
return;
}
auto rsType = PrimitiveRsTypeName(field.desc());
auto rsType = PrimitiveRsTypeName(field);
if (fieldType == FieldDescriptor::TYPE_STRING ||
fieldType == FieldDescriptor::TYPE_BYTES) {
field.Emit({{"field", fieldName},
{"self", self},
{"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk},
{"RsType", rsType},
{"maybe_mutator",
[&] {
if (is_mut) {
field.Emit({}, R"rs(
ctx.Emit({{"field", fieldName},
{"self", self},
{"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk},
{"RsType", rsType},
{"maybe_mutator",
[&] {
if (is_mut) {
ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> {
static VTABLE: $pbi$::BytesMutVTable =
$pbi$::BytesMutVTable::new(
@ -248,9 +249,9 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
}
}
)rs");
}
}}},
R"rs(
}
}}},
R"rs(
pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> {
let s = unsafe { $getter_thunk$($self$).as_ref() };
unsafe { __pb::ProtoStr::from_utf8_unchecked(s).into() }
@ -259,19 +260,19 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
$maybe_mutator$
)rs");
} else {
field.Emit({{"field", fieldName},
{"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk},
{"clearer_thunk", clearer_thunk},
{"self", self},
{"RsType", rsType},
{"maybe_mutator",
[&] {
// TODO: once the rust public api is accessible,
// by tooling, ensure that this only appears for the
// mutational pathway
if (is_mut && fieldType) {
field.Emit({}, R"rs(
ctx.Emit({{"field", fieldName},
{"getter_thunk", getter_thunk},
{"setter_thunk", setter_thunk},
{"clearer_thunk", clearer_thunk},
{"self", self},
{"RsType", rsType},
{"maybe_mutator",
[&] {
// TODO: once the rust public api is accessible,
// by tooling, ensure that this only appears for the
// mutational pathway
if (is_mut && fieldType) {
ctx.Emit({}, R"rs(
pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> {
static VTABLE: $pbi$::PrimitiveVTable<$RsType$> =
$pbi$::PrimitiveVTable::new(
@ -290,9 +291,9 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
}
}
)rs");
}
}}},
R"rs(
}
}}},
R"rs(
pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> {
unsafe { $getter_thunk$($self$) }
}
@ -302,93 +303,89 @@ void GetterForViewOrMut(Context<FieldDescriptor> field, bool is_mut) {
}
}
void AccessorsForViewOrMut(Context<Descriptor> msg, bool is_mut) {
for (int i = 0; i < msg.desc().field_count(); ++i) {
auto field = msg.WithDesc(*msg.desc().field(i));
if (field.desc().is_repeated()) continue;
void AccessorsForViewOrMut(Context& ctx, const Descriptor& msg, bool is_mut) {
for (int i = 0; i < msg.field_count(); ++i) {
const FieldDescriptor& field = *msg.field(i);
if (field.is_repeated()) continue;
// TODO - add cord support
if (field.desc().options().has_ctype()) continue;
if (field.options().has_ctype()) continue;
// TODO
if (field.desc().type() == FieldDescriptor::TYPE_ENUM ||
field.desc().type() == FieldDescriptor::TYPE_GROUP)
if (field.type() == FieldDescriptor::TYPE_ENUM ||
field.type() == FieldDescriptor::TYPE_GROUP)
continue;
GetterForViewOrMut(field, is_mut);
msg.printer().PrintRaw("\n");
GetterForViewOrMut(ctx, field, is_mut);
ctx.printer().PrintRaw("\n");
}
}
} // namespace
void GenerateRs(Context<Descriptor> msg) {
if (msg.desc().map_key() != nullptr) {
ABSL_LOG(WARNING) << "unsupported map field: " << msg.desc().full_name();
void GenerateRs(Context& ctx, const Descriptor& msg) {
if (msg.map_key() != nullptr) {
ABSL_LOG(WARNING) << "unsupported map field: " << msg.full_name();
return;
}
msg.Emit(
{{"Msg", msg.desc().name()},
{"Msg::new", [&] { MessageNew(msg); }},
{"Msg::serialize", [&] { MessageSerialize(msg); }},
{"Msg::deserialize", [&] { MessageDeserialize(msg); }},
{"Msg::drop", [&] { MessageDrop(msg); }},
{"Msg_externs", [&] { MessageExterns(msg); }},
{"accessor_fns",
[&] {
for (int i = 0; i < msg.desc().field_count(); ++i) {
auto field = msg.WithDesc(*msg.desc().field(i));
msg.Emit({{"comment", FieldInfoComment(field)}}, R"rs(
ctx.Emit({{"Msg", msg.name()},
{"Msg::new", [&] { MessageNew(ctx, msg); }},
{"Msg::serialize", [&] { MessageSerialize(ctx, msg); }},
{"Msg::deserialize", [&] { MessageDeserialize(ctx, msg); }},
{"Msg::drop", [&] { MessageDrop(ctx, msg); }},
{"Msg_externs", [&] { MessageExterns(ctx, msg); }},
{"accessor_fns",
[&] {
for (int i = 0; i < msg.field_count(); ++i) {
auto& field = *msg.field(i);
ctx.Emit({{"comment", FieldInfoComment(ctx, field)}}, R"rs(
// $comment$
)rs");
GenerateAccessorMsgImpl(field);
msg.printer().PrintRaw("\n");
}
}},
{"oneof_accessor_fns",
[&] {
for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) {
GenerateOneofAccessors(
msg.WithDesc(*msg.desc().real_oneof_decl(i)));
msg.printer().PrintRaw("\n");
}
}},
{"accessor_externs",
[&] {
for (int i = 0; i < msg.desc().field_count(); ++i) {
GenerateAccessorExternC(msg.WithDesc(*msg.desc().field(i)));
msg.printer().PrintRaw("\n");
}
}},
{"oneof_externs",
[&] {
for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) {
GenerateOneofExternC(msg.WithDesc(*msg.desc().real_oneof_decl(i)));
msg.printer().PrintRaw("\n");
}
}},
{"nested_msgs",
[&] {
// If we have no nested types or oneofs, bail out without emitting
// an empty mod SomeMsg_.
if (msg.desc().nested_type_count() == 0 &&
msg.desc().real_oneof_decl_count() == 0) {
return;
}
msg.Emit(
{{"Msg", msg.desc().name()},
{"nested_msgs",
[&] {
for (int i = 0; i < msg.desc().nested_type_count(); ++i) {
auto nested_msg = msg.WithDesc(msg.desc().nested_type(i));
GenerateRs(nested_msg);
}
}},
{"oneofs",
[&] {
for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) {
GenerateOneofDefinition(
msg.WithDesc(*msg.desc().real_oneof_decl(i)));
}
}}},
R"rs(
GenerateAccessorMsgImpl(ctx, field);
ctx.printer().PrintRaw("\n");
}
}},
{"oneof_accessor_fns",
[&] {
for (int i = 0; i < msg.real_oneof_decl_count(); ++i) {
GenerateOneofAccessors(ctx, *msg.real_oneof_decl(i));
ctx.printer().PrintRaw("\n");
}
}},
{"accessor_externs",
[&] {
for (int i = 0; i < msg.field_count(); ++i) {
GenerateAccessorExternC(ctx, *msg.field(i));
ctx.printer().PrintRaw("\n");
}
}},
{"oneof_externs",
[&] {
for (int i = 0; i < msg.real_oneof_decl_count(); ++i) {
GenerateOneofExternC(ctx, *msg.real_oneof_decl(i));
ctx.printer().PrintRaw("\n");
}
}},
{"nested_msgs",
[&] {
// If we have no nested types or oneofs, bail out without
// emitting an empty mod SomeMsg_.
if (msg.nested_type_count() == 0 &&
msg.real_oneof_decl_count() == 0) {
return;
}
ctx.Emit(
{{"Msg", msg.name()},
{"nested_msgs",
[&] {
for (int i = 0; i < msg.nested_type_count(); ++i) {
GenerateRs(ctx, *msg.nested_type(i));
}
}},
{"oneofs",
[&] {
for (int i = 0; i < msg.real_oneof_decl_count(); ++i) {
GenerateOneofDefinition(ctx, *msg.real_oneof_decl(i));
}
}}},
R"rs(
#[allow(non_snake_case)]
pub mod $Msg$_ {
$nested_msgs$
@ -396,10 +393,12 @@ void GenerateRs(Context<Descriptor> msg) {
$oneofs$
} // mod $Msg$_
)rs");
}},
{"accessor_fns_for_views", [&] { AccessorsForViewOrMut(msg, false); }},
{"accessor_fns_for_muts", [&] { AccessorsForViewOrMut(msg, true); }}},
R"rs(
}},
{"accessor_fns_for_views",
[&] { AccessorsForViewOrMut(ctx, msg, false); }},
{"accessor_fns_for_muts",
[&] { AccessorsForViewOrMut(ctx, msg, true); }}},
R"rs(
#[allow(non_camel_case_types)]
// TODO: Implement support for debug redaction
#[derive(Debug)]
@ -542,9 +541,9 @@ void GenerateRs(Context<Descriptor> msg) {
$nested_msgs$
)rs");
if (msg.is_cpp()) {
msg.printer().PrintRaw("\n");
msg.Emit({{"Msg", msg.desc().name()}}, R"rs(
if (ctx.is_cpp()) {
ctx.printer().PrintRaw("\n");
ctx.Emit({{"Msg", msg.name()}}, R"rs(
impl $Msg$ {
pub fn __unstable_wrap_cpp_grant_permission_to_break(msg: $pbi$::RawMessage) -> Self {
Self { inner: $pbr$::MessageInner { msg } }
@ -558,39 +557,37 @@ void GenerateRs(Context<Descriptor> msg) {
}
// Generates code for a particular message in `.pb.thunk.cc`.
void GenerateThunksCc(Context<Descriptor> msg) {
ABSL_CHECK(msg.is_cpp());
if (msg.desc().map_key() != nullptr) {
ABSL_LOG(WARNING) << "unsupported map field: " << msg.desc().full_name();
void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
ABSL_CHECK(ctx.is_cpp());
if (msg.map_key() != nullptr) {
ABSL_LOG(WARNING) << "unsupported map field: " << msg.full_name();
return;
}
msg.Emit(
ctx.Emit(
{{"abi", "\"C\""}, // Workaround for syntax highlight bug in VSCode.
{"Msg", msg.desc().name()},
{"QualifiedMsg", cpp::QualifiedClassName(&msg.desc())},
{"new_thunk", Thunk(msg, "new")},
{"delete_thunk", Thunk(msg, "delete")},
{"serialize_thunk", Thunk(msg, "serialize")},
{"deserialize_thunk", Thunk(msg, "deserialize")},
{"Msg", msg.name()},
{"QualifiedMsg", cpp::QualifiedClassName(&msg)},
{"new_thunk", Thunk(ctx, msg, "new")},
{"delete_thunk", Thunk(ctx, msg, "delete")},
{"serialize_thunk", Thunk(ctx, msg, "serialize")},
{"deserialize_thunk", Thunk(ctx, msg, "deserialize")},
{"nested_msg_thunks",
[&] {
for (int i = 0; i < msg.desc().nested_type_count(); ++i) {
Context<Descriptor> nested_msg =
msg.WithDesc(msg.desc().nested_type(i));
GenerateThunksCc(nested_msg);
for (int i = 0; i < msg.nested_type_count(); ++i) {
GenerateThunksCc(ctx, *msg.nested_type(i));
}
}},
{"accessor_thunks",
[&] {
for (int i = 0; i < msg.desc().field_count(); ++i) {
GenerateAccessorThunkCc(msg.WithDesc(*msg.desc().field(i)));
for (int i = 0; i < msg.field_count(); ++i) {
GenerateAccessorThunkCc(ctx, *msg.field(i));
}
}},
{"oneof_thunks",
[&] {
for (int i = 0; i < msg.desc().real_oneof_decl_count(); ++i) {
GenerateOneofThunkCc(msg.WithDesc(*msg.desc().real_oneof_decl(i)));
for (int i = 0; i < msg.real_oneof_decl_count(); ++i) {
GenerateOneofThunkCc(ctx, *msg.real_oneof_decl(i));
}
}}},
R"cc(

@ -21,10 +21,10 @@ namespace compiler {
namespace rust {
// Generates code for a particular message in `.pb.rs`.
void GenerateRs(Context<Descriptor> msg);
void GenerateRs(Context& ctx, const Descriptor& msg);
// Generates code for a particular message in `.pb.thunk.cc`.
void GenerateThunksCc(Context<Descriptor> msg);
void GenerateThunksCc(Context& ctx, const Descriptor& msg);
} // namespace rust
} // namespace compiler

@ -26,22 +26,23 @@ namespace protobuf {
namespace compiler {
namespace rust {
namespace {
std::string GetUnderscoreDelimitedFullName(Context<Descriptor> msg) {
std::string result = msg.desc().full_name();
std::string GetUnderscoreDelimitedFullName(Context& ctx,
const Descriptor& msg) {
std::string result = msg.full_name();
absl::StrReplaceAll({{".", "_"}}, &result);
return result;
}
} // namespace
std::string GetCrateName(Context<FileDescriptor> dep) {
absl::string_view path = dep.desc().name();
std::string GetCrateName(Context& ctx, const FileDescriptor& dep) {
absl::string_view path = dep.name();
auto basename = path.substr(path.rfind('/') + 1);
return absl::StrReplaceAll(basename, {{".", "_"}, {"-", "_"}});
}
std::string GetRsFile(Context<FileDescriptor> file) {
auto basename = StripProto(file.desc().name());
switch (auto k = file.opts().kernel) {
std::string GetRsFile(Context& ctx, const FileDescriptor& file) {
auto basename = StripProto(file.name());
switch (auto k = ctx.opts().kernel) {
case Kernel::kUpb:
return absl::StrCat(basename, ".u.pb.rs");
case Kernel::kCpp:
@ -52,42 +53,41 @@ std::string GetRsFile(Context<FileDescriptor> file) {
}
}
std::string GetThunkCcFile(Context<FileDescriptor> file) {
auto basename = StripProto(file.desc().name());
std::string GetThunkCcFile(Context& ctx, const FileDescriptor& file) {
auto basename = StripProto(file.name());
return absl::StrCat(basename, ".pb.thunks.cc");
}
std::string GetHeaderFile(Context<FileDescriptor> file) {
auto basename = StripProto(file.desc().name());
std::string GetHeaderFile(Context& ctx, const FileDescriptor& file) {
auto basename = StripProto(file.name());
return absl::StrCat(basename, ".proto.h");
}
namespace {
template <typename T>
std::string FieldPrefix(Context<T> field) {
// NOTE: When field.is_upb(), this functions outputs must match the symbols
std::string FieldPrefix(Context& ctx, const T& field) {
// NOTE: When ctx.is_upb(), this functions outputs must match the symbols
// that the upbc plugin generates exactly. Failure to do so correctly results
// in a link-time failure.
absl::string_view prefix = field.is_cpp() ? "__rust_proto_thunk__" : "";
std::string thunk_prefix =
absl::StrCat(prefix, GetUnderscoreDelimitedFullName(
field.WithDesc(field.desc().containing_type())));
absl::string_view prefix = ctx.is_cpp() ? "__rust_proto_thunk__" : "";
std::string thunk_prefix = absl::StrCat(
prefix, GetUnderscoreDelimitedFullName(ctx, *field.containing_type()));
return thunk_prefix;
}
template <typename T>
std::string Thunk(Context<T> field, absl::string_view op) {
std::string thunk = FieldPrefix(field);
std::string Thunk(Context& ctx, const T& field, absl::string_view op) {
std::string thunk = FieldPrefix(ctx, field);
absl::string_view format;
if (field.is_upb() && op == "get") {
if (ctx.is_upb() && op == "get") {
// upb getter is simply the field name (no "get" in the name).
format = "_$1";
} else if (field.is_upb() && op == "get_mut") {
} else if (ctx.is_upb() && op == "get_mut") {
// same as above, with with `mutable` prefix
format = "_mutable_$1";
} else if (field.is_upb() && op == "case") {
} else if (ctx.is_upb() && op == "case") {
// some upb functions are in the order x_op compared to has/set/clear which
// are in the other order e.g. op_x.
format = "_$1_$0";
@ -95,51 +95,53 @@ std::string Thunk(Context<T> field, absl::string_view op) {
format = "_$0_$1";
}
absl::SubstituteAndAppend(&thunk, format, op, field.desc().name());
absl::SubstituteAndAppend(&thunk, format, op, field.name());
return thunk;
}
std::string ThunkMapOrRepeated(Context<FieldDescriptor> field,
std::string ThunkMapOrRepeated(Context& ctx, const FieldDescriptor& field,
absl::string_view op) {
if (!field.is_upb()) {
return Thunk<FieldDescriptor>(field, op);
if (!ctx.is_upb()) {
return Thunk<FieldDescriptor>(ctx, field, op);
}
std::string thunk = absl::StrCat("_", FieldPrefix(field));
std::string thunk = absl::StrCat("_", FieldPrefix(ctx, field));
absl::string_view format;
if (op == "get") {
format = field.desc().is_map() ? "_$1_upb_map" : "_$1_upb_array";
format = field.is_map() ? "_$1_upb_map" : "_$1_upb_array";
} else if (op == "get_mut") {
format =
field.desc().is_map() ? "_$1_mutable_upb_map" : "_$1_mutable_upb_array";
format = field.is_map() ? "_$1_mutable_upb_map" : "_$1_mutable_upb_array";
} else {
return Thunk<FieldDescriptor>(field, op);
return Thunk<FieldDescriptor>(ctx, field, op);
}
absl::SubstituteAndAppend(&thunk, format, op, field.desc().name());
absl::SubstituteAndAppend(&thunk, format, op, field.name());
return thunk;
}
} // namespace
std::string Thunk(Context<FieldDescriptor> field, absl::string_view op) {
if (field.desc().is_map() || field.desc().is_repeated()) {
return ThunkMapOrRepeated(field, op);
std::string Thunk(Context& ctx, const FieldDescriptor& field,
absl::string_view op) {
if (field.is_map() || field.is_repeated()) {
return ThunkMapOrRepeated(ctx, field, op);
}
return Thunk<FieldDescriptor>(field, op);
return Thunk<FieldDescriptor>(ctx, field, op);
}
std::string Thunk(Context<OneofDescriptor> field, absl::string_view op) {
return Thunk<OneofDescriptor>(field, op);
std::string Thunk(Context& ctx, const OneofDescriptor& field,
absl::string_view op) {
return Thunk<OneofDescriptor>(ctx, field, op);
}
std::string Thunk(Context<Descriptor> msg, absl::string_view op) {
absl::string_view prefix = msg.is_cpp() ? "__rust_proto_thunk__" : "";
return absl::StrCat(prefix, GetUnderscoreDelimitedFullName(msg), "_", op);
std::string Thunk(Context& ctx, const Descriptor& msg, absl::string_view op) {
absl::string_view prefix = ctx.is_cpp() ? "__rust_proto_thunk__" : "";
return absl::StrCat(prefix, GetUnderscoreDelimitedFullName(ctx, msg), "_",
op);
}
std::string PrimitiveRsTypeName(const FieldDescriptor& desc) {
switch (desc.type()) {
std::string PrimitiveRsTypeName(const FieldDescriptor& field) {
switch (field.type()) {
case FieldDescriptor::TYPE_BOOL:
return "bool";
case FieldDescriptor::TYPE_INT32:
@ -167,7 +169,7 @@ std::string PrimitiveRsTypeName(const FieldDescriptor& desc) {
default:
break;
}
ABSL_LOG(FATAL) << "Unsupported field type: " << desc.type_name();
ABSL_LOG(FATAL) << "Unsupported field type: " << field.type_name();
return "";
}
@ -180,20 +182,18 @@ std::string PrimitiveRsTypeName(const FieldDescriptor& desc) {
//
// If the message has no package and no containing messages then this returns
// empty string.
std::string RustModule(Context<Descriptor> msg) {
const Descriptor& desc = msg.desc();
std::string RustModule(Context& ctx, const Descriptor& msg) {
std::vector<std::string> modules;
std::vector<std::string> package_modules =
absl::StrSplit(desc.file()->package(), '.', absl::SkipEmpty());
absl::StrSplit(msg.file()->package(), '.', absl::SkipEmpty());
modules.insert(modules.begin(), package_modules.begin(),
package_modules.end());
// Innermost to outermost order.
std::vector<std::string> modules_from_containing_types;
const Descriptor* parent = desc.containing_type();
const Descriptor* parent = msg.containing_type();
while (parent != nullptr) {
modules_from_containing_types.push_back(absl::StrCat(parent->name(), "_"));
parent = parent->containing_type();
@ -213,27 +213,25 @@ std::string RustModule(Context<Descriptor> msg) {
return absl::StrJoin(modules, "::");
}
std::string RustInternalModuleName(Context<FileDescriptor> file) {
std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file) {
// TODO: Introduce a more robust mangling here to avoid conflicts
// between `foo/bar/baz.proto` and `foo_bar/baz.proto`.
return absl::StrReplaceAll(StripProto(file.desc().name()), {{"/", "_"}});
return absl::StrReplaceAll(StripProto(file.name()), {{"/", "_"}});
}
std::string GetCrateRelativeQualifiedPath(Context<Descriptor> msg) {
return absl::StrCat(RustModule(msg), msg.desc().name());
std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg) {
return absl::StrCat(RustModule(ctx, msg), msg.name());
}
std::string FieldInfoComment(Context<FieldDescriptor> field) {
absl::string_view label =
field.desc().is_repeated() ? "repeated" : "optional";
std::string comment =
absl::StrCat(field.desc().name(), ": ", label, " ",
FieldDescriptor::TypeName(field.desc().type()));
std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field) {
absl::string_view label = field.is_repeated() ? "repeated" : "optional";
std::string comment = absl::StrCat(field.name(), ": ", label, " ",
FieldDescriptor::TypeName(field.type()));
if (auto* m = field.desc().message_type()) {
if (auto* m = field.message_type()) {
absl::StrAppend(&comment, " ", m->full_name());
}
if (auto* m = field.desc().enum_type()) {
if (auto* m = field.enum_type()) {
absl::StrAppend(&comment, " ", m->full_name());
}

@ -19,25 +19,27 @@ namespace google {
namespace protobuf {
namespace compiler {
namespace rust {
std::string GetCrateName(Context<FileDescriptor> dep);
std::string GetCrateName(Context& ctx, const FileDescriptor& dep);
std::string GetRsFile(Context<FileDescriptor> file);
std::string GetThunkCcFile(Context<FileDescriptor> file);
std::string GetHeaderFile(Context<FileDescriptor> file);
std::string GetRsFile(Context& ctx, const FileDescriptor& file);
std::string GetThunkCcFile(Context& ctx, const FileDescriptor& file);
std::string GetHeaderFile(Context& ctx, const FileDescriptor& file);
std::string Thunk(Context<FieldDescriptor> field, absl::string_view op);
std::string Thunk(Context<OneofDescriptor> field, absl::string_view op);
std::string Thunk(Context& ctx, const FieldDescriptor& field,
absl::string_view op);
std::string Thunk(Context& ctx, const OneofDescriptor& field,
absl::string_view op);
std::string Thunk(Context<Descriptor> msg, absl::string_view op);
std::string Thunk(Context& ctx, const Descriptor& msg, absl::string_view op);
std::string PrimitiveRsTypeName(const FieldDescriptor& desc);
std::string PrimitiveRsTypeName(const FieldDescriptor& field);
std::string FieldInfoComment(Context<FieldDescriptor> field);
std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field);
std::string RustModule(Context<Descriptor> msg);
std::string RustInternalModuleName(Context<FileDescriptor> file);
std::string RustModule(Context& ctx, const Descriptor& msg);
std::string RustInternalModuleName(Context& ctx, const FileDescriptor& file);
std::string GetCrateRelativeQualifiedPath(Context<Descriptor> msg);
std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg);
} // namespace rust
} // namespace compiler

@ -78,27 +78,25 @@ std::string ToCamelCase(absl::string_view name) {
return cpp::UnderscoresToCamelCase(name, /* upper initial letter */ true);
}
std::string oneofViewEnumRsName(const OneofDescriptor& desc) {
return ToCamelCase(desc.name());
std::string oneofViewEnumRsName(const OneofDescriptor& oneof) {
return ToCamelCase(oneof.name());
}
std::string oneofMutEnumRsName(const OneofDescriptor& desc) {
return ToCamelCase(desc.name()) + "Mut";
std::string oneofMutEnumRsName(const OneofDescriptor& oneof) {
return ToCamelCase(oneof.name()) + "Mut";
}
std::string oneofCaseEnumName(const OneofDescriptor& desc) {
std::string oneofCaseEnumName(const OneofDescriptor& oneof) {
// Note: This is the name used for the cpp Case enum, we use it for both
// the Rust Case enum as well as for the cpp case enum in the cpp thunk.
return ToCamelCase(desc.name()) + "Case";
return ToCamelCase(oneof.name()) + "Case";
}
std::string RsTypeNameView(Context<FieldDescriptor> field) {
const auto& desc = field.desc();
if (desc.options().has_ctype()) {
std::string RsTypeNameView(Context& ctx, const FieldDescriptor& field) {
if (field.options().has_ctype()) {
return ""; // TODO: b/308792377 - ctype fields not supported yet.
}
switch (desc.type()) {
switch (field.type()) {
case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_INT64:
case FieldDescriptor::TYPE_FIXED32:
@ -112,7 +110,7 @@ std::string RsTypeNameView(Context<FieldDescriptor> field) {
case FieldDescriptor::TYPE_FLOAT:
case FieldDescriptor::TYPE_DOUBLE:
case FieldDescriptor::TYPE_BOOL:
return PrimitiveRsTypeName(desc);
return PrimitiveRsTypeName(field);
case FieldDescriptor::TYPE_BYTES:
return "&'msg [u8]";
case FieldDescriptor::TYPE_STRING:
@ -120,23 +118,21 @@ std::string RsTypeNameView(Context<FieldDescriptor> field) {
case FieldDescriptor::TYPE_MESSAGE:
return absl::StrCat(
"::__pb::View<'msg, crate::",
GetCrateRelativeQualifiedPath(field.WithDesc(desc.message_type())),
">");
GetCrateRelativeQualifiedPath(ctx, *field.message_type()), ">");
case FieldDescriptor::TYPE_ENUM: // TODO: b/300257770 - Support enums.
case FieldDescriptor::TYPE_GROUP: // Not supported yet.
return "";
}
ABSL_LOG(FATAL) << "Unexpected field type: " << desc.type_name();
ABSL_LOG(FATAL) << "Unexpected field type: " << field.type_name();
return "";
}
std::string RsTypeNameMut(Context<FieldDescriptor> field) {
const auto& desc = field.desc();
if (desc.options().has_ctype()) {
std::string RsTypeNameMut(Context& ctx, const FieldDescriptor& field) {
if (field.options().has_ctype()) {
return ""; // TODO: b/308792377 - ctype fields not supported yet.
}
switch (desc.type()) {
switch (field.type()) {
case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_INT64:
case FieldDescriptor::TYPE_FIXED32:
@ -152,56 +148,54 @@ std::string RsTypeNameMut(Context<FieldDescriptor> field) {
case FieldDescriptor::TYPE_BOOL:
case FieldDescriptor::TYPE_BYTES:
case FieldDescriptor::TYPE_STRING:
return absl::StrCat("::__pb::Mut<'msg, ", PrimitiveRsTypeName(desc), ">");
return absl::StrCat("::__pb::Mut<'msg, ", PrimitiveRsTypeName(field),
">");
case FieldDescriptor::TYPE_MESSAGE:
return absl::StrCat(
"::__pb::Mut<'msg, crate::",
GetCrateRelativeQualifiedPath(field.WithDesc(desc.message_type())),
">");
GetCrateRelativeQualifiedPath(ctx, *field.message_type()), ">");
case FieldDescriptor::TYPE_ENUM: // TODO: b/300257770 - Support enums.
case FieldDescriptor::TYPE_GROUP: // Not supported yet.
return "";
}
ABSL_LOG(FATAL) << "Unexpected field type: " << desc.type_name();
ABSL_LOG(FATAL) << "Unexpected field type: " << field.type_name();
return "";
}
} // namespace
void GenerateOneofDefinition(Context<OneofDescriptor> oneof) {
const auto& desc = oneof.desc();
oneof.Emit(
{{"view_enum_name", oneofViewEnumRsName(desc)},
{"mut_enum_name", oneofMutEnumRsName(desc)},
void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) {
ctx.Emit(
{{"view_enum_name", oneofViewEnumRsName(oneof)},
{"mut_enum_name", oneofMutEnumRsName(oneof)},
{"view_fields",
[&] {
for (int i = 0; i < desc.field_count(); ++i) {
const auto& field = *desc.field(i);
std::string rs_type = RsTypeNameView(oneof.WithDesc(field));
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameView(ctx, field);
if (rs_type.empty()) {
continue;
}
oneof.Emit({{"name", ToCamelCase(field.name())},
{"type", rs_type},
{"number", std::to_string(field.number())}},
R"rs($name$($type$) = $number$,
ctx.Emit({{"name", ToCamelCase(field.name())},
{"type", rs_type},
{"number", std::to_string(field.number())}},
R"rs($name$($type$) = $number$,
)rs");
}
}},
{"mut_fields",
[&] {
for (int i = 0; i < desc.field_count(); ++i) {
const auto& field = *desc.field(i);
std::string rs_type = RsTypeNameMut(oneof.WithDesc(field));
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameMut(ctx, field);
if (rs_type.empty()) {
continue;
}
oneof.Emit({{"name", ToCamelCase(field.name())},
{"type", rs_type},
{"number", std::to_string(field.number())}},
R"rs($name$($type$) = $number$,
ctx.Emit({{"name", ToCamelCase(field.name())},
{"type", rs_type},
{"number", std::to_string(field.number())}},
R"rs($name$($type$) = $number$,
)rs");
}
}}},
@ -236,18 +230,18 @@ void GenerateOneofDefinition(Context<OneofDescriptor> oneof) {
// Note: This enum is used as the Thunk return type for getting which case is
// used: it exactly matches the generate case enum that both cpp and upb use.
oneof.Emit({{"case_enum_name", oneofCaseEnumName(desc)},
{"cases",
[&] {
for (int i = 0; i < desc.field_count(); ++i) {
const auto& field = desc.field(i);
oneof.Emit({{"name", ToCamelCase(field->name())},
{"number", std::to_string(field->number())}},
R"rs($name$ = $number$,
ctx.Emit({{"case_enum_name", oneofCaseEnumName(oneof)},
{"cases",
[&] {
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
ctx.Emit({{"name", ToCamelCase(field.name())},
{"number", std::to_string(field.number())}},
R"rs($name$ = $number$,
)rs");
}
}}},
R"rs(
}
}}},
R"rs(
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) enum $case_enum_name$ {
@ -260,23 +254,21 @@ void GenerateOneofDefinition(Context<OneofDescriptor> oneof) {
)rs");
}
void GenerateOneofAccessors(Context<OneofDescriptor> oneof) {
const auto& desc = oneof.desc();
oneof.Emit(
{{"oneof_name", desc.name()},
{"view_enum_name", oneofViewEnumRsName(desc)},
{"mut_enum_name", oneofMutEnumRsName(desc)},
{"case_enum_name", oneofCaseEnumName(desc)},
void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof) {
ctx.Emit(
{{"oneof_name", oneof.name()},
{"view_enum_name", oneofViewEnumRsName(oneof)},
{"mut_enum_name", oneofMutEnumRsName(oneof)},
{"case_enum_name", oneofCaseEnumName(oneof)},
{"view_cases",
[&] {
for (int i = 0; i < desc.field_count(); ++i) {
const auto& field = *desc.field(i);
std::string rs_type = RsTypeNameView(oneof.WithDesc(field));
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameView(ctx, field);
if (rs_type.empty()) {
continue;
}
oneof.Emit(
ctx.Emit(
{
{"case", ToCamelCase(field.name())},
{"rs_getter", field.name()},
@ -288,13 +280,13 @@ void GenerateOneofAccessors(Context<OneofDescriptor> oneof) {
}},
{"mut_cases",
[&] {
for (int i = 0; i < desc.field_count(); ++i) {
const auto& field = *desc.field(i);
std::string rs_type = RsTypeNameMut(oneof.WithDesc(field));
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameMut(ctx, field);
if (rs_type.empty()) {
continue;
}
oneof.Emit(
ctx.Emit(
{{"case", ToCamelCase(field.name())},
{"rs_mut_getter", field.name() + "_mut"},
{"type", rs_type},
@ -321,7 +313,7 @@ void GenerateOneofAccessors(Context<OneofDescriptor> oneof) {
$Msg$_::$mut_enum_name$::$case$(self.$rs_mut_getter$()$into_mut_transform$), )rs");
}
}},
{"case_thunk", Thunk(oneof, "case")}},
{"case_thunk", Thunk(ctx, oneof, "case")}},
R"rs(
pub fn r#$oneof_name$(&self) -> $Msg$_::$view_enum_name$ {
match unsafe { $case_thunk$(self.inner.msg) } {
@ -340,26 +332,24 @@ void GenerateOneofAccessors(Context<OneofDescriptor> oneof) {
)rs");
}
void GenerateOneofExternC(Context<OneofDescriptor> oneof) {
const auto& desc = oneof.desc();
oneof.Emit(
void GenerateOneofExternC(Context& ctx, const OneofDescriptor& oneof) {
ctx.Emit(
{
{"case_enum_rs_name", oneofCaseEnumName(desc)},
{"case_thunk", Thunk(oneof, "case")},
{"case_enum_rs_name", oneofCaseEnumName(oneof)},
{"case_thunk", Thunk(ctx, oneof, "case")},
},
R"rs(
fn $case_thunk$(raw_msg: $pbi$::RawMessage) -> $Msg$_::$case_enum_rs_name$;
)rs");
}
void GenerateOneofThunkCc(Context<OneofDescriptor> oneof) {
const auto& desc = oneof.desc();
oneof.Emit(
void GenerateOneofThunkCc(Context& ctx, const OneofDescriptor& oneof) {
ctx.Emit(
{
{"oneof_name", desc.name()},
{"case_enum_name", oneofCaseEnumName(desc)},
{"case_thunk", Thunk(oneof, "case")},
{"QualifiedMsg", cpp::QualifiedClassName(desc.containing_type())},
{"oneof_name", oneof.name()},
{"case_enum_name", oneofCaseEnumName(oneof)},
{"case_thunk", Thunk(ctx, oneof, "case")},
{"QualifiedMsg", cpp::QualifiedClassName(oneof.containing_type())},
},
R"cc(
$QualifiedMsg$::$case_enum_name$ $case_thunk$($QualifiedMsg$* msg) {

@ -16,10 +16,10 @@ namespace protobuf {
namespace compiler {
namespace rust {
void GenerateOneofDefinition(Context<OneofDescriptor> oneof);
void GenerateOneofAccessors(Context<OneofDescriptor> oneof);
void GenerateOneofExternC(Context<OneofDescriptor> oneof);
void GenerateOneofThunkCc(Context<OneofDescriptor> oneof);
void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof);
void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof);
void GenerateOneofExternC(Context& ctx, const OneofDescriptor& oneof);
void GenerateOneofThunkCc(Context& ctx, const OneofDescriptor& oneof);
} // namespace rust
} // namespace compiler

Loading…
Cancel
Save