Add Repeated<Message> accessors

Shares logic with Repeated<Scalar> accessors.

PiperOrigin-RevId: 599263714
pull/15434/head
Kevin King 1 year ago committed by Copybara-Service
parent b04a213326
commit 9bd8dfa639
  1. 29
      rust/test/shared/accessors_repeated_test.rs
  2. 4
      src/google/protobuf/compiler/rust/accessors/accessor_generator.h
  3. 10
      src/google/protobuf/compiler/rust/accessors/accessors.cc
  4. 63
      src/google/protobuf/compiler/rust/accessors/repeated_field.cc

@ -7,7 +7,7 @@
use googletest::prelude::*;
use paste::paste;
use unittest_proto::proto2_unittest::{TestAllTypes, TestAllTypes_};
use unittest_proto::proto2_unittest::{TestAllTypes, TestAllTypes_, TestAllTypes_::NestedMessage};
macro_rules! generate_repeated_numeric_test {
($(($t: ty, $field: ident)),*) => {
@ -161,3 +161,30 @@ fn test_repeated_bool_set() {
assert_that!(mutator.iter().collect::<Vec<_>>(), eq(mutator2.iter().collect::<Vec<_>>()));
}
#[test]
fn test_repeated_message() {
let mut msg = TestAllTypes::new();
assert_that!(msg.repeated_nested_message().len(), eq(0));
let mut nested = NestedMessage::new();
nested.bb_mut().set(1);
msg.repeated_nested_message_mut().push(nested.as_view());
assert_that!(msg.repeated_nested_message().get(0).unwrap().bb(), eq(1));
let mut msg2 = TestAllTypes::new();
msg2.repeated_nested_message_mut().copy_from(msg.repeated_nested_message());
assert_that!(msg2.repeated_nested_message().get(0).unwrap().bb(), eq(1));
msg2.repeated_nested_message_mut().clear();
assert_that!(msg2.repeated_nested_message().len(), eq(0));
let mut nested2 = NestedMessage::new();
nested2.bb_mut().set(2);
msg.repeated_nested_message_mut().set(0, nested2.as_view());
assert_that!(msg.repeated_nested_message().get(0).unwrap().bb(), eq(2));
assert_that!(
msg.repeated_nested_message().iter().map(|m| m.bb()).collect::<Vec<_>>(),
eq(vec![2]),
);
}

@ -103,9 +103,9 @@ class SingularMessage final : public AccessorGenerator {
void InThunkCc(Context& ctx, const FieldDescriptor& field) const override;
};
class RepeatedScalar final : public AccessorGenerator {
class RepeatedField final : public AccessorGenerator {
public:
~RepeatedScalar() override = default;
~RepeatedField() override = default;
void InMsgImpl(Context& ctx, const FieldDescriptor& field,
AccessorCase accessor_case) const override;
void InExternC(Context& ctx, const FieldDescriptor& field) const override;

@ -59,7 +59,7 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
case FieldDescriptor::TYPE_DOUBLE:
case FieldDescriptor::TYPE_BOOL:
if (field.is_repeated()) {
return std::make_unique<RepeatedScalar>();
return std::make_unique<RepeatedField>();
}
return std::make_unique<SingularScalar>();
case FieldDescriptor::TYPE_ENUM:
@ -70,7 +70,7 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
" (defined in a separate Rust crate) are not supported");
}
if (field.is_repeated()) {
return std::make_unique<RepeatedScalar>();
return std::make_unique<RepeatedField>();
}
return std::make_unique<SingularScalar>();
case FieldDescriptor::TYPE_BYTES:
@ -80,15 +80,15 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
}
return std::make_unique<SingularString>();
case FieldDescriptor::TYPE_MESSAGE:
if (field.is_repeated()) {
return std::make_unique<UnsupportedField>("repeated msg not supported");
}
// TODO: support messages which are defined in other crates.
if (!IsInCurrentlyGeneratingCrate(ctx, *field.message_type())) {
return std::make_unique<UnsupportedField>(
"message fields that are imported from another proto_library"
" (defined in a separate Rust crate) are not supported");
}
if (field.is_repeated()) {
return std::make_unique<RepeatedField>();
}
return std::make_unique<SingularMessage>();
case FieldDescriptor::TYPE_GROUP:

@ -5,6 +5,8 @@
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
#include <string>
#include "absl/strings/string_view.h"
#include "google/protobuf/compiler/cpp/helpers.h"
#include "google/protobuf/compiler/rust/accessors/accessor_case.h"
@ -18,24 +20,24 @@ namespace protobuf {
namespace compiler {
namespace rust {
void RepeatedScalar::InMsgImpl(Context& ctx, const FieldDescriptor& field,
AccessorCase accessor_case) const {
void RepeatedField::InMsgImpl(Context& ctx, const FieldDescriptor& field,
AccessorCase accessor_case) const {
ctx.Emit({{"field", RsSafeName(field.name())},
{"Scalar", RsTypePath(ctx, field)},
{"RsType", RsTypePath(ctx, field)},
{"getter_thunk", ThunkName(ctx, field, "get")},
{"getter_mut_thunk", ThunkName(ctx, field, "get_mut")},
{"getter",
[&] {
if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
pub fn $field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> {
pub fn $field$(&self) -> $pb$::RepeatedView<'_, $RsType$> {
unsafe {
$getter_thunk$(
self.raw_msg(),
/* optional size pointer */ std::ptr::null(),
) }
.map_or_else(
$pbr$::empty_array::<$Scalar$>,
$pbr$::empty_array::<$RsType$>,
|raw| unsafe {
$pb$::RepeatedView::from_raw($pbi$::Private, raw)
}
@ -44,7 +46,7 @@ void RepeatedScalar::InMsgImpl(Context& ctx, const FieldDescriptor& field,
)rs");
} else {
ctx.Emit({}, R"rs(
pub fn $field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> {
pub fn $field$(&self) -> $pb$::RepeatedView<'_, $RsType$> {
unsafe {
$pb$::RepeatedView::from_raw(
$pbi$::Private,
@ -63,7 +65,7 @@ void RepeatedScalar::InMsgImpl(Context& ctx, const FieldDescriptor& field,
}
if (ctx.is_upb()) {
ctx.Emit({}, R"rs(
pub fn $field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> {
pub fn $field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $RsType$> {
unsafe {
$pb$::RepeatedMut::from_inner(
$pbi$::Private,
@ -82,7 +84,7 @@ void RepeatedScalar::InMsgImpl(Context& ctx, const FieldDescriptor& field,
)rs");
} else {
ctx.Emit({}, R"rs(
pub fn $field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> {
pub fn $field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $RsType$> {
unsafe {
$pb$::RepeatedMut::from_inner(
$pbi$::Private,
@ -102,10 +104,9 @@ void RepeatedScalar::InMsgImpl(Context& ctx, const FieldDescriptor& field,
)rs");
}
void RepeatedScalar::InExternC(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"Scalar", RsTypePath(ctx, field)},
{"getter_thunk", ThunkName(ctx, field, "get")},
void RepeatedField::InExternC(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"getter_thunk", ThunkName(ctx, field, "get")},
{"getter_mut_thunk", ThunkName(ctx, field, "get_mut")},
{"getter",
[&] {
@ -136,10 +137,38 @@ void RepeatedScalar::InExternC(Context& ctx,
)rs");
}
void RepeatedScalar::InThunkCc(Context& ctx,
const FieldDescriptor& field) const {
bool IsRepeatedPrimitive(const FieldDescriptor& field) {
return field.cpp_type() == FieldDescriptor::CPPTYPE_ENUM ||
field.cpp_type() == FieldDescriptor::CPPTYPE_BOOL ||
field.cpp_type() == FieldDescriptor::CPPTYPE_DOUBLE ||
field.cpp_type() == FieldDescriptor::CPPTYPE_FLOAT ||
field.cpp_type() == FieldDescriptor::CPPTYPE_INT32 ||
field.cpp_type() == FieldDescriptor::CPPTYPE_INT64 ||
field.cpp_type() == FieldDescriptor::CPPTYPE_UINT32 ||
field.cpp_type() == FieldDescriptor::CPPTYPE_UINT64;
}
std::string CppElementType(const FieldDescriptor& field) {
if (IsRepeatedPrimitive(field)) {
return cpp::PrimitiveTypeName(field.cpp_type());
} else {
return cpp::QualifiedClassName(field.message_type());
}
}
const char* CppRepeatedContainerType(const FieldDescriptor& field) {
if (IsRepeatedPrimitive(field)) {
return "google::protobuf::RepeatedField";
} else {
return "google::protobuf::RepeatedPtrField";
}
}
void RepeatedField::InThunkCc(Context& ctx,
const FieldDescriptor& field) const {
ctx.Emit({{"field", cpp::FieldName(&field)},
{"Scalar", cpp::PrimitiveTypeName(field.cpp_type())},
{"ElementType", CppElementType(field)},
{"ContainerType", CppRepeatedContainerType(field)},
{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"clearer_thunk", ThunkName(ctx, field, "clear")},
{"getter_thunk", ThunkName(ctx, field, "get")},
@ -151,10 +180,10 @@ void RepeatedScalar::InThunkCc(Context& ctx,
void $clearer_thunk$($QualifiedMsg$* msg) {
msg->clear_$field$();
}
google::protobuf::RepeatedField<$Scalar$>* $getter_mut_thunk$($QualifiedMsg$* msg) {
$ContainerType$<$ElementType$>* $getter_mut_thunk$($QualifiedMsg$* msg) {
return msg->mutable_$field$();
}
const google::protobuf::RepeatedField<$Scalar$>* $getter_thunk$(
const $ContainerType$<$ElementType$>* $getter_thunk$(
const $QualifiedMsg$* msg) {
return &msg->$field$();
}

Loading…
Cancel
Save