Remove the oneof_mut accessor and expose the oneof_case accessor on gencode api.

PiperOrigin-RevId: 615782249
pull/16137/head
Protobuf Team Bot 9 months ago committed by Copybara-Service
parent 2f0fe0641c
commit 7ec56d4243
  1. 62
      rust/test/shared/accessors_proto3_test.rs
  2. 118
      rust/test/shared/accessors_test.rs
  3. 237
      src/google/protobuf/compiler/rust/oneof.cc

@ -307,74 +307,18 @@ fn test_oneof_accessors_view_long_lifetime() {
fn test_oneof_enum_accessors() {
use unittest_proto3::{
TestOneof2,
TestOneof2_::{Foo, NestedEnum},
TestOneof2_::{Foo, FooCase, NestedEnum},
};
let mut msg = TestOneof2::new();
assert_that!(msg.foo_enum_opt(), eq(Optional::Unset(NestedEnum::Unknown)));
assert_that!(msg.foo(), matches_pattern!(Foo::not_set(_)));
assert_that!(msg.foo_case(), matches_pattern!(FooCase::not_set));
msg.set_foo_enum(NestedEnum::Bar);
assert_that!(msg.foo_enum_opt(), eq(Optional::Set(NestedEnum::Bar)));
assert_that!(msg.foo(), matches_pattern!(Foo::FooEnum(eq(NestedEnum::Bar))));
}
#[test]
fn test_oneof_mut_accessors() {
use TestAllTypes_::OneofFieldMut::*;
let mut msg = TestAllTypes::new();
assert_that!(msg.oneof_field_mut(), matches_pattern!(not_set(_)));
msg.set_oneof_uint32(7);
match msg.oneof_field_mut() {
OneofUint32(mut v) => {
assert_that!(v.get(), eq(7));
v.set(8);
assert_that!(v.get(), eq(8));
}
f => panic!("unexpected field_mut type! {:?}", f),
}
// Confirm that the mut write above applies to both the field accessor and the
// oneof view accessor.
assert_that!(msg.oneof_uint32_opt(), eq(Optional::Set(8)));
assert_that!(
msg.oneof_field(),
matches_pattern!(TestAllTypes_::OneofField::OneofUint32(eq(8)))
);
// Clearing a different field in the same oneof doesn't affect the other, set
// field.
msg.clear_oneof_bytes();
assert_that!(
msg.oneof_field(),
matches_pattern!(TestAllTypes_::OneofField::OneofUint32(eq(8)))
);
msg.clear_oneof_uint32();
assert_that!(msg.oneof_field_mut(), matches_pattern!(not_set(_)));
msg.set_oneof_uint32(7);
msg.set_oneof_bytes(b"123");
assert_that!(msg.oneof_field_mut(), matches_pattern!(OneofBytes(_)));
}
#[test]
fn test_oneof_mut_enum_accessors() {
use unittest_proto3::{
TestOneof2,
TestOneof2_::{FooMut, NestedEnum},
};
let mut msg = TestOneof2::new();
assert_that!(msg.foo_enum_opt(), eq(Optional::Unset(NestedEnum::Unknown)));
assert_that!(msg.foo_mut(), matches_pattern!(FooMut::not_set(_)));
msg.set_foo_enum(NestedEnum::Bar);
assert_that!(msg.foo_enum_opt(), eq(Optional::Set(NestedEnum::Bar)));
assert_that!(msg.foo_mut(), matches_pattern!(FooMut::FooEnum(_)));
assert_that!(msg.foo_case(), matches_pattern!(FooCase::FooEnum));
}
#[test]

@ -924,27 +924,32 @@ fn test_default_import_enum_accessors() {
#[test]
fn test_oneof_accessors() {
use unittest_proto::TestOneof2;
use unittest_proto::TestOneof2_::{Foo::*, NestedEnum};
use unittest_proto::TestOneof2_::{Foo::*, FooCase, NestedEnum};
let mut msg = TestOneof2::new();
assert_that!(msg.foo(), matches_pattern!(not_set(_)));
assert_that!(msg.foo_case(), eq(FooCase::not_set));
msg.foo_int_mut().set(7);
assert_that!(msg.foo_int_opt(), eq(Optional::Set(7)));
assert_that!(msg.foo(), matches_pattern!(FooInt(eq(7))));
assert_that!(msg.foo_case(), eq(FooCase::FooInt));
msg.foo_int_mut().clear();
assert_that!(msg.foo_int_opt(), eq(Optional::Unset(0)));
assert_that!(msg.foo(), matches_pattern!(not_set(_)));
assert_that!(msg.foo_case(), eq(FooCase::not_set));
msg.foo_int_mut().set(7);
msg.foo_bytes_mut().set(b"123");
assert_that!(msg.foo_int_opt(), eq(Optional::Unset(0)));
assert_that!(msg.foo(), matches_pattern!(FooBytes(eq(b"123"))));
assert_that!(msg.foo_case(), eq(FooCase::FooBytes));
msg.foo_enum_mut().set(NestedEnum::Foo);
assert_that!(msg.foo(), matches_pattern!(FooEnum(eq(NestedEnum::Foo))));
assert_that!(msg.foo_case(), eq(FooCase::FooEnum));
// Test the accessors or $Msg$Mut
let mut msg_mut = msg.as_mut();
@ -952,68 +957,21 @@ fn test_oneof_accessors() {
msg_mut.foo_int_mut().set(7);
msg_mut.foo_bytes_mut().set(b"123");
assert_that!(msg_mut.foo(), matches_pattern!(FooBytes(eq(b"123"))));
assert_that!(msg_mut.foo_case(), eq(FooCase::FooBytes));
assert_that!(msg_mut.foo_int_opt(), eq(Optional::Unset(0)));
// Test the accessors on $Msg$View
let msg_view = msg.as_view();
assert_that!(msg_view.foo(), matches_pattern!(FooBytes(eq(b"123"))));
assert_that!(msg_view.foo_case(), eq(FooCase::FooBytes));
assert_that!(msg_view.foo_int_opt(), eq(Optional::Unset(0)));
// TODO: Add tests covering a message-type field in a oneof.
}
#[test]
fn test_oneof_mut_accessors() {
use unittest_proto::TestOneof2;
use unittest_proto::TestOneof2_::{Foo, FooMut::*, NestedEnum};
let mut msg = TestOneof2::new();
assert_that!(msg.foo_mut(), matches_pattern!(not_set(_)));
msg.foo_int_mut().set(7);
match msg.foo_mut() {
FooInt(mut v) => {
assert_that!(v.get(), eq(7));
v.set(8);
assert_that!(v.get(), eq(8));
}
f => panic!("unexpected field_mut type! {:?}", f),
}
// Confirm that the mut write above applies to both the field accessor and the
// oneof view accessor.
assert_that!(msg.foo_int_opt(), eq(Optional::Set(8)));
assert_that!(msg.foo(), matches_pattern!(Foo::FooInt(_)));
msg.foo_int_mut().clear();
assert_that!(msg.foo_mut(), matches_pattern!(not_set(_)));
msg.foo_int_mut().set(7);
msg.foo_bytes_mut().set(b"123");
assert_that!(msg.foo_mut(), matches_pattern!(FooBytes(_)));
msg.foo_enum_mut().set(NestedEnum::Baz);
assert_that!(msg.foo_mut(), matches_pattern!(FooEnum(_)));
// Test the mut accessors or $Msg$Mut
let mut msg_mut = msg.as_mut();
match msg_mut.foo_mut() {
FooEnum(mut v) => {
assert_that!(v.get(), eq(NestedEnum::Baz));
v.set(NestedEnum::Bar);
assert_that!(v.get(), eq(NestedEnum::Bar));
}
f => panic!("unexpected field_mut type! {:?}", f),
}
assert_that!(msg.foo_enum(), eq(NestedEnum::Bar));
// TODO: Add tests covering a message-type field in a oneof.
}
#[test]
fn test_msg_oneof_default_accessors() {
use unittest_proto::TestOneof2_::{Bar::*, NestedEnum};
use unittest_proto::TestOneof2_::{Bar::*, BarCase, NestedEnum};
let mut msg = unittest_proto::TestOneof2::new();
assert_that!(msg.bar(), matches_pattern!(not_set(_)));
@ -1021,78 +979,28 @@ fn test_msg_oneof_default_accessors() {
msg.bar_int_mut().set(7);
assert_that!(msg.bar_int_opt(), eq(Optional::Set(7)));
assert_that!(msg.bar(), matches_pattern!(BarInt(eq(7))));
assert_that!(msg.bar_case(), eq(BarCase::BarInt));
msg.bar_int_mut().clear();
assert_that!(msg.bar_int_opt(), eq(Optional::Unset(5)));
assert_that!(msg.bar(), matches_pattern!(not_set(_)));
assert_that!(msg.bar_case(), eq(BarCase::not_set));
msg.bar_int_mut().set(7);
msg.bar_bytes_mut().set(b"123");
assert_that!(msg.bar_int_opt(), eq(Optional::Unset(5)));
assert_that!(msg.bar_enum_opt(), eq(Optional::Unset(NestedEnum::Bar)));
assert_that!(msg.bar(), matches_pattern!(BarBytes(eq(b"123"))));
assert_that!(msg.bar_case(), eq(BarCase::BarBytes));
msg.bar_enum_mut().set(NestedEnum::Baz);
assert_that!(msg.bar(), matches_pattern!(BarEnum(eq(NestedEnum::Baz))));
assert_that!(msg.bar_case(), eq(BarCase::BarEnum));
assert_that!(msg.bar_int_opt(), eq(Optional::Unset(5)));
// TODO: Add tests covering a message-type field in a oneof.
}
#[test]
fn test_oneof_default_mut_accessors() {
use unittest_proto::TestOneof2_::{Bar, BarMut, BarMut::*, NestedEnum};
let mut msg = unittest_proto::TestOneof2::new();
assert_that!(msg.bar_mut(), matches_pattern!(not_set(_)));
msg.bar_int_mut().set(7);
match msg.bar_mut() {
BarInt(mut v) => {
assert_that!(v.get(), eq(7));
v.set(8);
assert_that!(v.get(), eq(8));
}
f => panic!("unexpected field_mut type! {:?}", f),
}
// Confirm that the mut write above applies to all three of:
// - The field accessor
// - The oneof mut accessor
// - The oneof view accessor
// And then each of the applicable cases on:
// - The owned msg directly
// - The msg as a $Msg$Mut
// - The msg as a $Msg$View
assert_that!(msg.bar_int_opt(), eq(Optional::Set(8)));
assert_that!(msg.bar_mut(), matches_pattern!(BarMut::BarInt(_)));
assert_that!(msg.bar(), matches_pattern!(Bar::BarInt(_)));
let mut msg_mut = msg.as_mut();
assert_that!(msg_mut.bar_int_opt(), eq(Optional::Set(8)));
assert_that!(msg_mut.bar_mut(), matches_pattern!(BarMut::BarInt(_)));
assert_that!(msg_mut.bar(), matches_pattern!(Bar::BarInt(_)));
let msg_view = msg.as_view();
assert_that!(msg_view.bar_int_opt(), eq(Optional::Set(8)));
// This test correctly fails to compile if this line is uncommented:
// assert_that!(msg_view.bar_mut(), matches_pattern!(BarMut::BarInt(_)));
assert_that!(msg_view.bar(), matches_pattern!(Bar::BarInt(_)));
msg.bar_int_mut().clear();
assert_that!(msg.bar_mut(), matches_pattern!(not_set(_)));
msg.bar_int_mut().set(7);
msg.bar_bytes_mut().set(b"123");
assert_that!(msg.bar_mut(), matches_pattern!(BarBytes(_)));
msg.bar_enum_mut().set(NestedEnum::Baz);
assert_that!(msg.bar_mut(), matches_pattern!(BarEnum(_)));
// TODO: Add tests covering a message-type field in a oneof.
}
#[test]
fn test_set_message_from_view() {
use protobuf::MutProxy;

@ -24,17 +24,12 @@ namespace protobuf {
namespace compiler {
namespace rust {
// We emit three Rust enums:
// For each oneof we emit two Rust enums with corresponding accessors:
// - An enum acting as a tagged union that has each case holds a View<> of
// each of the cases. Named as the one_of name in CamelCase.
// - An enum acting as a tagged union that has each case holds a Mut<> of
// each of the cases. Named as one_of name in CamelCase with "Mut" appended.
// [TODO: Mut not implemented yet].
// - A simple enum whose cases have int values matching the cpp or upb's
// case enum. Named as the one_of camelcase with "Case" appended.
// All three contain cases matching the fields in the oneof CamelCased.
// The first and second are exposed in the API, the third is internal and
// used for interop with the Kernels in the generation of the other two.
// - A simple 'which oneof field is set' enum which directly maps to the
// underlying enum used for the 'cases' accessor in C++ or upb. Named as the
// one_of camelcase with "Case" appended.
//
// Example:
// For this oneof:
@ -47,28 +42,30 @@ namespace rust {
//
// This will emit as the exposed API:
// pub mod SomeMsg_ {
// // The 'view' struct (no suffix on the name)
// pub enum SomeOneof<'msg> {
// FieldA(i32) = 7,
// FieldB(View<'msg, SomeMsg>) = 9,
// not_set(std::marker::PhantomData<&'msg ()>) = 0
// }
// pub enum SomeOneofMut<'msg> {
// FieldA(Mut<'msg, i32>) = 7,
// FieldB(Mut<'msg, SomeMsg>) = 9,
// not_set(std::marker::PhantomData<&'msg ()>) = 0
//
// #[repr(C)]
// pub enum SomeOneofCase {
// FieldA = 7,
// FieldB = 9,
// not_set = 0
// }
// }
// impl SomeMsg {
// pub fn some_oneof(&self) -> SomeOneof {...}
// pub fn some_oneof_mut(&mut self) -> SomeOneofMut {...}
// pub fn some_oneof_case(&self) -> SomeOneofCase {...}
// }
// impl SomeMsgMut {
// pub fn some_oneof(&self) -> SomeOneof {...}
// pub fn some_oneof_mut(&mut self) -> SomeOneofMut {...}
// pub fn some_oneof_case(&self) -> SomeOneofCase {...}
// }
// impl SomeMsgView {
// pub fn some_oneof(&self) -> SomeOneof {...}
// pub fn some_oneof(self) -> SomeOneof {...}
// pub fn some_oneof_case(self) -> SomeOneofCase {...}
// }
//
// An additional "Case" enum which just reflects the corresponding slot numbers
@ -110,71 +107,28 @@ std::string RsTypeNameView(Context& ctx, const FieldDescriptor& field) {
return "";
}
// A user-friendly rust type for a mutator of this field with lifetime 'msg.
std::string RsTypeNameMut(Context& ctx, const FieldDescriptor& field) {
if (field.options().has_ctype()) {
return ""; // TODO: b/308792377 - ctype fields not supported yet.
}
switch (GetRustFieldType(field)) {
case RustFieldType::INT32:
case RustFieldType::INT64:
case RustFieldType::UINT32:
case RustFieldType::UINT64:
case RustFieldType::FLOAT:
case RustFieldType::DOUBLE:
case RustFieldType::BOOL:
return absl::StrCat("::__pb::PrimitiveMut<'msg, ", RsTypePath(ctx, field),
">");
case RustFieldType::BYTES:
return "::__pb::BytesMut<'msg>";
case RustFieldType::STRING:
return "::__pb::ProtoStrMut<'msg>";
case RustFieldType::MESSAGE:
return absl::StrCat("::__pb::Mut<'msg, ", RsTypePath(ctx, field), ">");
case RustFieldType::ENUM:
return absl::StrCat("::__pb::Mut<'msg, ", RsTypePath(ctx, field), ">");
}
ABSL_LOG(FATAL) << "Unexpected field type: " << field.type_name();
return "";
}
} // namespace
void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) {
ctx.Emit(
{{"view_enum_name", OneofViewEnumRsName(oneof)},
{"mut_enum_name", OneofMutEnumRsName(oneof)},
{"view_fields",
[&] {
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameView(ctx, field);
if (rs_type.empty()) {
continue;
}
ctx.Emit({{"name", OneofCaseRsName(field)},
{"type", rs_type},
{"number", std::to_string(field.number())}},
R"rs($name$($type$) = $number$,
)rs");
}
}},
{"mut_fields",
[&] {
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameMut(ctx, field);
if (rs_type.empty()) {
continue;
}
ctx.Emit({{"name", OneofCaseRsName(field)},
{"type", rs_type},
{"number", std::to_string(field.number())}},
R"rs($name$($type$) = $number$,
{
{"view_enum_name", OneofViewEnumRsName(oneof)},
{"view_fields",
[&] {
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameView(ctx, field);
if (rs_type.empty()) {
continue;
}
ctx.Emit({{"name", OneofCaseRsName(field)},
{"type", rs_type},
{"number", std::to_string(field.number())}},
R"rs($name$($type$) = $number$,
)rs");
}
}}},
}
}},
},
// TODO: Revisit if isize is the optimal repr for this enum.
// TODO: not_set currently has phantom data just to avoid the
// lifetime on the enum breaking compilation if there are zero supported
@ -190,18 +144,6 @@ void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) {
#[allow(non_camel_case_types)]
not_set(std::marker::PhantomData<&'msg ()>) = 0
}
#[non_exhaustive]
#[derive(Debug)]
#[allow(dead_code)]
#[repr(isize)]
pub enum $mut_enum_name$<'msg> {
$mut_fields$
#[allow(non_camel_case_types)]
not_set(std::marker::PhantomData<&'msg ()>) = 0
}
)rs");
// Note: This enum is used as the Thunk return type for getting which case is
@ -221,7 +163,7 @@ void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) {
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[allow(dead_code)]
pub(super) enum $case_enum_name$ {
pub enum $case_enum_name$ {
$cases$
#[allow(non_camel_case_types)]
@ -234,92 +176,45 @@ void GenerateOneofDefinition(Context& ctx, const OneofDescriptor& oneof) {
void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof,
AccessorCase accessor_case) {
ctx.Emit(
{{"oneof_name", RsSafeName(oneof.name())},
{"view_lifetime", ViewLifetime(accessor_case)},
{"view_self", ViewReceiver(accessor_case)},
{"view_enum_name", OneofViewEnumRsName(oneof)},
{"mut_enum_name", OneofMutEnumRsName(oneof)},
{"case_enum_name", OneofCaseEnumRsName(oneof)},
{"view_cases",
[&] {
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameView(ctx, field);
if (rs_type.empty()) {
continue;
}
ctx.Emit(
{
{"case", OneofCaseRsName(field)},
{"rs_getter", RsSafeName(field.name())},
{"type", rs_type},
},
R"rs(
{
{"oneof_name", RsSafeName(oneof.name())},
{"view_lifetime", ViewLifetime(accessor_case)},
{"self", ViewReceiver(accessor_case)},
{"view_enum_name", OneofViewEnumRsName(oneof)},
{"case_enum_name", OneofCaseEnumRsName(oneof)},
{"view_cases",
[&] {
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameView(ctx, field);
if (rs_type.empty()) {
continue;
}
ctx.Emit(
{
{"case", OneofCaseRsName(field)},
{"rs_getter", RsSafeName(field.name())},
{"type", rs_type},
},
R"rs(
$Msg$_::$case_enum_name$::$case$ =>
$Msg$_::$view_enum_name$::$case$(self.$rs_getter$()),
)rs");
}
}},
{"mut_cases",
[&] {
for (int i = 0; i < oneof.field_count(); ++i) {
auto& field = *oneof.field(i);
std::string rs_type = RsTypeNameMut(ctx, field);
if (rs_type.empty()) {
continue;
}
ctx.Emit(
{{"case", OneofCaseRsName(field)},
{"rs_mut_getter", field.name() + "_mut"},
{"type", rs_type}},
// Any extra behavior needed to map the mut getter into the
// unwrapped Mut<>. Right now Message's _mut already returns
// the Mut directly, but for scalars the accessor will return
// an Optional which we then grab the mut by doing
// .try_into_mut().unwrap().
//
// Note that this unwrap() is safe because the flow is:
// 1) Find out which oneof field is already set (if any)
// 2) If a field is set, call the corresponding field's _mut()
// and wrap the result in the SomeOneofMut enum.
// The unwrap() will only ever panic if the which oneof enum
// disagrees with the corresponding field presence which.
R"rs(
$Msg$_::$case_enum_name$::$case$ =>
$Msg$_::$mut_enum_name$::$case$(
self.$rs_mut_getter$().try_into_mut().unwrap()),
)rs");
}
}},
{"case_thunk", ThunkName(ctx, oneof, "case")},
{"getter",
[&] {
ctx.Emit({}, R"rs(
pub fn $oneof_name$($view_self$) -> $Msg$_::$view_enum_name$<$view_lifetime$> {
match unsafe { $case_thunk$(self.raw_msg()) } {
$view_cases$
_ => $Msg$_::$view_enum_name$::not_set(std::marker::PhantomData)
}
}
)rs");
}},
{"getter_mut",
[&] {
if (accessor_case == AccessorCase::VIEW) {
return;
}
ctx.Emit({}, R"rs(
pub fn $oneof_name$_mut(&mut self) -> $Msg$_::$mut_enum_name$ {
match unsafe { $case_thunk$(self.raw_msg()) } {
$mut_cases$
_ => $Msg$_::$mut_enum_name$::not_set(std::marker::PhantomData)
}
}},
{"case_thunk", ThunkName(ctx, oneof, "case")},
},
R"rs(
pub fn $oneof_name$($self$) -> $Msg$_::$view_enum_name$<$view_lifetime$> {
match $self$.$oneof_name$_case() {
$view_cases$
_ => $Msg$_::$view_enum_name$::not_set(std::marker::PhantomData)
}
}
)rs");
}}},
R"rs(
$getter$
$getter_mut$
pub fn $oneof_name$_case($self$) -> $Msg$_::$case_enum_name$ {
unsafe { $case_thunk$(self.raw_msg()) }
}
)rs");
}

Loading…
Cancel
Save