Support enums as map values

PiperOrigin-RevId: 609772950
pull/15914/head
Alyssa Haroldsen 1 year ago committed by Copybara-Service
parent 3590a8b7e8
commit fe5092a392
  1. 3
      rust/test/shared/accessors_map_test.rs
  2. 1
      src/google/protobuf/compiler/rust/BUILD.bazel
  3. 9
      src/google/protobuf/compiler/rust/accessors/accessors.cc
  4. 63
      src/google/protobuf/compiler/rust/accessors/map.cc
  5. 238
      src/google/protobuf/compiler/rust/enum.cc
  6. 4
      src/google/protobuf/compiler/rust/enum.h
  7. 13
      src/google/protobuf/compiler/rust/generator.cc
  8. 30
      src/google/protobuf/compiler/rust/naming.cc
  9. 6
      src/google/protobuf/compiler/rust/naming.h

@ -6,7 +6,7 @@
// https://developers.google.com/open-source/licenses/bsd
use googletest::prelude::*;
use map_unittest_proto::{TestMap, TestMapWithMessages};
use map_unittest_proto::{MapEnum, TestMap, TestMapWithMessages};
use paste::paste;
use std::collections::HashMap;
use unittest_proto::TestAllTypes;
@ -91,6 +91,7 @@ generate_map_primitives_tests!(
(i32, f64, int32, double, 1, 1.),
(bool, bool, bool, bool, true, true),
(i32, &[u8], int32, bytes, 1, b"foo"),
(i32, MapEnum, int32, enum, 1, MapEnum::Baz),
);
#[test]

@ -163,6 +163,7 @@ cc_library(
":context",
":naming",
"//src/google/protobuf",
"//src/google/protobuf/compiler/cpp:names",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:absl_check",

@ -34,14 +34,7 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
}
if (field.is_map()) {
auto value_type = field.message_type()->map_value()->type();
switch (value_type) {
case FieldDescriptor::TYPE_ENUM:
return std::make_unique<UnsupportedField>(
"Maps with values of type enum are not supported");
default:
return std::make_unique<Map>();
}
return std::make_unique<Map>();
}
if (field.is_repeated()) {

@ -19,6 +19,21 @@ namespace google {
namespace protobuf {
namespace compiler {
namespace rust {
namespace {
std::string MapElementTypeName(const FieldDescriptor& field) {
auto cpp_type = field.cpp_type();
switch (cpp_type) {
case FieldDescriptor::CPPTYPE_MESSAGE:
return cpp::QualifiedClassName(field.message_type());
case FieldDescriptor::CPPTYPE_ENUM:
return cpp::QualifiedClassName(field.enum_type());
default:
return cpp::PrimitiveTypeName(cpp_type);
}
}
} // namespace
void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field,
AccessorCase accessor_case) const {
@ -117,38 +132,24 @@ void Map::InExternC(Context& ctx, const FieldDescriptor& field) const {
)rs");
}
std::string MapElementTypeName(FieldDescriptor::CppType cpp_type,
const Descriptor* message_type) {
if (cpp_type == FieldDescriptor::CPPTYPE_MESSAGE ||
cpp_type == FieldDescriptor::CPPTYPE_ENUM) {
return cpp::QualifiedClassName(message_type);
}
return cpp::PrimitiveTypeName(cpp_type);
}
void Map::InThunkCc(Context& ctx, const FieldDescriptor& field) const {
ctx.Emit(
{{"field", cpp::FieldName(&field)},
{"Key",
MapElementTypeName(field.message_type()->map_key()->cpp_type(),
field.message_type()->map_key()->message_type())},
{"Value",
MapElementTypeName(field.message_type()->map_value()->cpp_type(),
field.message_type()->map_value()->message_type())},
{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"getter_thunk", ThunkName(ctx, field, "get")},
{"getter_mut_thunk", ThunkName(ctx, field, "get_mut")},
{"impls",
[&] {
ctx.Emit(
R"cc(
const void* $getter_thunk$(const $QualifiedMsg$* msg) {
return &msg->$field$();
}
void* $getter_mut_thunk$($QualifiedMsg$* msg) { return msg->mutable_$field$(); }
)cc");
}}},
"$impls$");
ctx.Emit({{"field", cpp::FieldName(&field)},
{"Key", MapElementTypeName(*field.message_type()->map_key())},
{"Value", MapElementTypeName(*field.message_type()->map_value())},
{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
{"getter_thunk", ThunkName(ctx, field, "get")},
{"getter_mut_thunk", ThunkName(ctx, field, "get_mut")},
{"impls",
[&] {
ctx.Emit(
R"cc(
const void* $getter_thunk$(const $QualifiedMsg$* msg) {
return &msg->$field$();
}
void* $getter_mut_thunk$($QualifiedMsg$* msg) { return msg->mutable_$field$(); }
)cc");
}}},
"$impls$");
}
} // namespace rust

@ -21,6 +21,7 @@
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "google/protobuf/compiler/cpp/names.h"
#include "google/protobuf/compiler/rust/context.h"
#include "google/protobuf/compiler/rust/naming.h"
#include "google/protobuf/descriptor.h"
@ -44,6 +45,225 @@ std::vector<std::pair<absl::string_view, int32_t>> EnumValuesInput(
return result;
}
void EnumProxiedInMapValue(Context& ctx, const EnumDescriptor& desc) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
for (const auto& t : kMapKeyTypes) {
ctx.Emit(
{{"map_new_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "new")},
{"map_free_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "free")},
{"map_clear_thunk",
RawMapThunk(ctx, desc, t.thunk_ident, "clear")},
{"map_size_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "size")},
{"map_insert_thunk",
RawMapThunk(ctx, desc, t.thunk_ident, "insert")},
{"map_get_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "get")},
{"map_remove_thunk",
RawMapThunk(ctx, desc, t.thunk_ident, "remove")},
{"map_iter_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "iter")},
{"map_iter_get_thunk",
RawMapThunk(ctx, desc, t.thunk_ident, "iter_get")},
{"to_ffi_key_expr", t.rs_to_ffi_key_expr},
io::Printer::Sub("ffi_key_t", [&] { ctx.Emit(t.rs_ffi_key_t); })
.WithSuffix(""),
io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); })
.WithSuffix(""),
io::Printer::Sub("from_ffi_key_expr",
[&] { ctx.Emit(t.rs_from_ffi_key_expr); })
.WithSuffix("")},
R"rs(
extern "C" {
fn $map_new_thunk$() -> $pbi$::RawMap;
fn $map_free_thunk$(m: $pbi$::RawMap);
fn $map_clear_thunk$(m: $pbi$::RawMap);
fn $map_size_thunk$(m: $pbi$::RawMap) -> usize;
fn $map_insert_thunk$(m: $pbi$::RawMap, key: $ffi_key_t$, value: $name$) -> bool;
fn $map_get_thunk$(m: $pbi$::RawMap, key: $ffi_key_t$, value: *mut $name$) -> bool;
fn $map_remove_thunk$(m: $pbi$::RawMap, key: $ffi_key_t$, value: *mut $name$) -> bool;
fn $map_iter_thunk$(m: $pbi$::RawMap) -> $pbr$::UntypedMapIterator;
fn $map_iter_get_thunk$(iter: &mut $pbr$::UntypedMapIterator, key: *mut $ffi_key_t$, value: *mut $name$);
}
impl $pb$::ProxiedInMapValue<$key_t$> for $name$ {
fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> {
unsafe {
$pb$::Map::from_inner(
$pbi$::Private,
$pbr$::InnerMapMut::new($pbi$::Private, $map_new_thunk$())
)
}
}
unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) {
unsafe { $map_free_thunk$(map.as_raw($pbi$::Private)); }
}
fn map_clear(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>) {
unsafe { $map_clear_thunk$(map.as_raw($pbi$::Private)); }
}
fn map_len(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> usize {
unsafe { $map_size_thunk$(map.as_raw($pbi$::Private)) }
}
fn map_insert(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>, value: $pb$::View<'_, Self>) -> bool {
unsafe { $map_insert_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value) }
}
fn map_get<'a>(map: $pb$::View<'a, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> Option<$pb$::View<'a, Self>> {
let key = $to_ffi_key_expr$;
let mut value = $std$::mem::MaybeUninit::uninit();
let found = unsafe { $map_get_thunk$(map.as_raw($pbi$::Private), key, value.as_mut_ptr()) };
if !found {
return None;
}
Some(unsafe { value.assume_init() })
}
fn map_remove(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> bool {
let mut value = $std$::mem::MaybeUninit::uninit();
unsafe { $map_remove_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value.as_mut_ptr()) }
}
fn map_iter(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> $pb$::MapIter<'_, $key_t$, Self> {
// SAFETY:
// - The backing map for `map.as_raw` is valid for at least '_.
// - A View that is live for '_ guarantees the backing map is unmodified for '_.
// - The `iter` function produces an iterator that is valid for the key
// and value types, and live for at least '_.
unsafe {
$pb$::MapIter::from_raw(
$pbi$::Private,
$map_iter_thunk$(map.as_raw($pbi$::Private))
)
}
}
fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> Option<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> {
// SAFETY:
// - The `MapIter` API forbids the backing map from being mutated for 'a,
// and guarantees that it's the correct key and value types.
// - The thunk is safe to call as long as the iterator isn't at the end.
// - The thunk always writes to key and value fields and does not read.
// - The thunk does not increment the iterator.
unsafe {
iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>(
$pbi$::Private,
$map_iter_get_thunk$,
|ffi_key| $from_ffi_key_expr$,
$std$::convert::identity,
)
}
}
}
)rs");
}
return;
case Kernel::kUpb:
for (const auto& t : kMapKeyTypes) {
ctx.Emit({io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); })
.WithSuffix("")},
R"rs(
impl $pb$::ProxiedInMapValue<$key_t$> for $name$ {
fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> {
let arena = $pbr$::Arena::new();
let raw_arena = arena.raw();
std::mem::forget(arena);
unsafe {
$pb$::Map::from_inner(
$pbi$::Private,
$pbr$::InnerMapMut::new(
$pbi$::Private,
$pbr$::upb_Map_New(
raw_arena,
<$key_t$ as $pbr$::UpbTypeConversions>::upb_type(),
$pbr$::UpbCType::Enum),
raw_arena))
}
}
unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) {
// SAFETY:
// - `map.raw_arena($pbi$::Private)` is a live `upb_Arena*`
// - This function is only called once for `map` in `Drop`.
unsafe {
$pbr$::upb_Arena_Free(map.inner($pbi$::Private).raw_arena($pbi$::Private));
}
}
fn map_clear(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>) {
unsafe {
$pbr$::upb_Map_Clear(map.as_raw($pbi$::Private));
}
}
fn map_len(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> usize {
unsafe {
$pbr$::upb_Map_Size(map.as_raw($pbi$::Private))
}
}
fn map_insert(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>, value: $pb$::View<'_, Self>) -> bool {
let arena = map.inner($pbi$::Private).raw_arena($pbi$::Private);
unsafe {
$pbr$::upb_Map_InsertAndReturnIfInserted(
map.as_raw($pbi$::Private),
<$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key),
$pbr$::upb_MessageValue { int32_val: value.0 },
arena
)
}
}
fn map_get<'a>(map: $pb$::View<'a, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> Option<$pb$::View<'a, Self>> {
let mut val = $std$::mem::MaybeUninit::uninit();
let found = unsafe {
$pbr$::upb_Map_Get(
map.as_raw($pbi$::Private),
<$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key),
val.as_mut_ptr())
};
if !found {
return None;
}
Some($name$(unsafe { val.assume_init().int32_val }))
}
fn map_remove(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> bool {
let mut val = $std$::mem::MaybeUninit::uninit();
unsafe {
$pbr$::upb_Map_Delete(
map.as_raw($pbi$::Private),
<$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key),
val.as_mut_ptr())
}
}
fn map_iter(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> $pb$::MapIter<'_, $key_t$, Self> {
// SAFETY: View<Map<'_,..>> guarantees its RawMap outlives '_.
unsafe {
$pb$::MapIter::from_raw($pbi$::Private, $pbr$::RawMapIter::new($pbi$::Private, map.as_raw($pbi$::Private)))
}
}
fn map_iter_next<'a>(
iter: &mut $pb$::MapIter<'a, $key_t$, Self>
) -> Option<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> {
// SAFETY: MapIter<'a, ..> guarantees its RawMapIter outlives 'a.
unsafe { iter.as_raw_mut($pbi$::Private).next_unchecked($pbi$::Private) }
// SAFETY: MapIter<K, V> returns key and values message values
// with the variants for K and V active.
.map(|(k, v)| unsafe {(
<$key_t$ as $pbr$::UpbTypeConversions>::from_message_value(k),
Self(v.int32_val),
)})
}
}
)rs");
}
return;
}
}
} // namespace
std::vector<RustEnumValue> EnumValues(
@ -149,6 +369,7 @@ void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) {
)rs");
}
}},
{"impl_proxied_in_map", [&] { EnumProxiedInMapValue(ctx, desc); }},
},
R"rs(
#[repr(transparent)]
@ -277,9 +498,26 @@ void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) {
unsafe impl $pbi$::Enum for $name$ {
const NAME: &'static str = "$name$";
}
$impl_proxied_in_map$
)rs");
}
void GenerateEnumThunksCc(Context& ctx, const EnumDescriptor& desc) {
ctx.Emit(
{
{"cpp_t", cpp::QualifiedClassName(&desc)},
{"rs_t", GetUnderscoreDelimitedFullName(ctx, desc)},
{"abi", "\"C\""}, // Workaround for syntax highlight bug in VSCode.
},
R"cc(
extern $abi$ {
__PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE(
$cpp_t$, $rs_t$, $cpp_t$, value, cpp_value)
}
)cc");
}
} // namespace rust
} // namespace compiler
} // namespace protobuf

@ -23,8 +23,12 @@ namespace protobuf {
namespace compiler {
namespace rust {
// Generates code for a particular enum in `.pb.rs`.
void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc);
// Generates code for a particular enum in `.pb.thunk.cc`.
void GenerateEnumThunksCc(Context& ctx, const EnumDescriptor& desc);
// An enum value with a unique number and any aliases for it.
struct RustEnumValue {
// The canonical CamelCase name in Rust.

@ -227,8 +227,19 @@ bool RustGenerator::Generate(const FileDescriptor* file,
}
for (int i = 0; i < file->enum_type_count(); ++i) {
GenerateEnumDefinition(ctx, *file->enum_type(i));
auto& enum_ = *file->enum_type(i);
GenerateEnumDefinition(ctx, enum_);
ctx.printer().PrintRaw("\n");
if (ctx.is_cpp()) {
auto thunks_ctx = ctx.WithPrinter(thunks_printer.get());
thunks_ctx.Emit({{"enum", enum_.full_name()}}, R"cc(
// $enum$
)cc");
GenerateEnumThunksCc(thunks_ctx, enum_);
thunks_ctx.printer().PrintRaw("\n");
}
}
return true;

@ -35,14 +35,6 @@ namespace google {
namespace protobuf {
namespace compiler {
namespace rust {
namespace {
std::string GetUnderscoreDelimitedFullName(Context& ctx,
const Descriptor& msg) {
std::string result = msg.full_name();
absl::StrReplaceAll({{".", "_"}}, &result);
return result;
}
} // namespace
std::string GetCrateName(Context& ctx, const FileDescriptor& dep) {
return RsSafeName(ctx.generator_context().ImportPathToCrateName(dep.name()));
@ -73,10 +65,16 @@ std::string GetHeaderFile(Context& ctx, const FileDescriptor& file) {
std::string RawMapThunk(Context& ctx, const Descriptor& msg,
absl::string_view key_t, absl::string_view op) {
return absl::StrCat("__rust_proto_thunk__", key_t, "_",
return absl::StrCat("__rust_proto_thunk__Map_", key_t, "_",
GetUnderscoreDelimitedFullName(ctx, *&msg), "_", op);
}
std::string RawMapThunk(Context& ctx, const EnumDescriptor& desc,
absl::string_view key_t, absl::string_view op) {
return absl::StrCat("__rust_proto_thunk__Map_", key_t, "_",
GetUnderscoreDelimitedFullName(ctx, *&desc), "_", op);
}
namespace {
template <typename T>
@ -197,6 +195,20 @@ std::string GetFullyQualifiedPath(Context& ctx, const EnumDescriptor& enum_) {
return absl::StrCat(GetCrateName(ctx, *enum_.file()), "::", rel_path);
}
std::string GetUnderscoreDelimitedFullName(Context& ctx,
const Descriptor& msg) {
std::string result = msg.full_name();
absl::StrReplaceAll({{".", "_"}}, &result);
return result;
}
std::string GetUnderscoreDelimitedFullName(Context& ctx,
const EnumDescriptor& enum_) {
std::string result = enum_.full_name();
absl::StrReplaceAll({{".", "_"}}, &result);
return result;
}
std::string RsTypePath(Context& ctx, const FieldDescriptor& field) {
switch (GetRustFieldType(field)) {
case RustFieldType::BOOL:

@ -35,6 +35,8 @@ std::string ThunkName(Context& ctx, const Descriptor& msg,
absl::string_view op);
std::string RawMapThunk(Context& ctx, const Descriptor& msg,
absl::string_view key_t, absl::string_view op);
std::string RawMapThunk(Context& ctx, const EnumDescriptor& desc,
absl::string_view key_t, absl::string_view op);
// Returns the local constant that defines the vtable for mutating `field`.
std::string VTableName(const FieldDescriptor& field);
@ -75,6 +77,10 @@ std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg);
std::string GetCrateRelativeQualifiedPath(Context& ctx,
const EnumDescriptor& enum_);
std::string GetUnderscoreDelimitedFullName(Context& ctx, const Descriptor& msg);
std::string GetUnderscoreDelimitedFullName(Context& ctx,
const EnumDescriptor& enum_);
// TODO: Unify these with other case-conversion functions.
// Converts an UpperCamel or lowerCamel string to a snake_case string.

Loading…
Cancel
Save