diff --git a/rust/test/shared/accessors_map_test.rs b/rust/test/shared/accessors_map_test.rs index 75730b67b8..930b1e5fe7 100644 --- a/rust/test/shared/accessors_map_test.rs +++ b/rust/test/shared/accessors_map_test.rs @@ -6,7 +6,7 @@ // https://developers.google.com/open-source/licenses/bsd use googletest::prelude::*; -use map_unittest_proto::{TestMap, TestMapWithMessages}; +use map_unittest_proto::{MapEnum, TestMap, TestMapWithMessages}; use paste::paste; use std::collections::HashMap; use unittest_proto::TestAllTypes; @@ -91,6 +91,7 @@ generate_map_primitives_tests!( (i32, f64, int32, double, 1, 1.), (bool, bool, bool, bool, true, true), (i32, &[u8], int32, bytes, 1, b"foo"), + (i32, MapEnum, int32, enum, 1, MapEnum::Baz), ); #[test] diff --git a/src/google/protobuf/compiler/rust/BUILD.bazel b/src/google/protobuf/compiler/rust/BUILD.bazel index acaceb7dbb..6b6f98754c 100644 --- a/src/google/protobuf/compiler/rust/BUILD.bazel +++ b/src/google/protobuf/compiler/rust/BUILD.bazel @@ -163,6 +163,7 @@ cc_library( ":context", ":naming", "//src/google/protobuf", + "//src/google/protobuf/compiler/cpp:names", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_check", diff --git a/src/google/protobuf/compiler/rust/accessors/accessors.cc b/src/google/protobuf/compiler/rust/accessors/accessors.cc index 886136dedb..9aa654577f 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessors.cc +++ b/src/google/protobuf/compiler/rust/accessors/accessors.cc @@ -34,14 +34,7 @@ std::unique_ptr AccessorGeneratorFor( } if (field.is_map()) { - auto value_type = field.message_type()->map_value()->type(); - switch (value_type) { - case FieldDescriptor::TYPE_ENUM: - return std::make_unique( - "Maps with values of type enum are not supported"); - default: - return std::make_unique(); - } + return std::make_unique(); } if (field.is_repeated()) { diff --git a/src/google/protobuf/compiler/rust/accessors/map.cc b/src/google/protobuf/compiler/rust/accessors/map.cc index 2818ec57fd..56e49abcd1 100644 --- a/src/google/protobuf/compiler/rust/accessors/map.cc +++ b/src/google/protobuf/compiler/rust/accessors/map.cc @@ -19,6 +19,21 @@ namespace google { namespace protobuf { namespace compiler { namespace rust { +namespace { + +std::string MapElementTypeName(const FieldDescriptor& field) { + auto cpp_type = field.cpp_type(); + switch (cpp_type) { + case FieldDescriptor::CPPTYPE_MESSAGE: + return cpp::QualifiedClassName(field.message_type()); + case FieldDescriptor::CPPTYPE_ENUM: + return cpp::QualifiedClassName(field.enum_type()); + default: + return cpp::PrimitiveTypeName(cpp_type); + } +} + +} // namespace void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field, AccessorCase accessor_case) const { @@ -117,38 +132,24 @@ void Map::InExternC(Context& ctx, const FieldDescriptor& field) const { )rs"); } -std::string MapElementTypeName(FieldDescriptor::CppType cpp_type, - const Descriptor* message_type) { - if (cpp_type == FieldDescriptor::CPPTYPE_MESSAGE || - cpp_type == FieldDescriptor::CPPTYPE_ENUM) { - return cpp::QualifiedClassName(message_type); - } - return cpp::PrimitiveTypeName(cpp_type); -} - void Map::InThunkCc(Context& ctx, const FieldDescriptor& field) const { - ctx.Emit( - {{"field", cpp::FieldName(&field)}, - {"Key", - MapElementTypeName(field.message_type()->map_key()->cpp_type(), - field.message_type()->map_key()->message_type())}, - {"Value", - MapElementTypeName(field.message_type()->map_value()->cpp_type(), - field.message_type()->map_value()->message_type())}, - {"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())}, - {"getter_thunk", ThunkName(ctx, field, "get")}, - {"getter_mut_thunk", ThunkName(ctx, field, "get_mut")}, - {"impls", - [&] { - ctx.Emit( - R"cc( - const void* $getter_thunk$(const $QualifiedMsg$* msg) { - return &msg->$field$(); - } - void* $getter_mut_thunk$($QualifiedMsg$* msg) { return msg->mutable_$field$(); } - )cc"); - }}}, - "$impls$"); + ctx.Emit({{"field", cpp::FieldName(&field)}, + {"Key", MapElementTypeName(*field.message_type()->map_key())}, + {"Value", MapElementTypeName(*field.message_type()->map_value())}, + {"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())}, + {"getter_thunk", ThunkName(ctx, field, "get")}, + {"getter_mut_thunk", ThunkName(ctx, field, "get_mut")}, + {"impls", + [&] { + ctx.Emit( + R"cc( + const void* $getter_thunk$(const $QualifiedMsg$* msg) { + return &msg->$field$(); + } + void* $getter_mut_thunk$($QualifiedMsg$* msg) { return msg->mutable_$field$(); } + )cc"); + }}}, + "$impls$"); } } // namespace rust diff --git a/src/google/protobuf/compiler/rust/enum.cc b/src/google/protobuf/compiler/rust/enum.cc index 7539dedfc8..7154ed0e96 100644 --- a/src/google/protobuf/compiler/rust/enum.cc +++ b/src/google/protobuf/compiler/rust/enum.cc @@ -21,6 +21,7 @@ #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "google/protobuf/compiler/cpp/names.h" #include "google/protobuf/compiler/rust/context.h" #include "google/protobuf/compiler/rust/naming.h" #include "google/protobuf/descriptor.h" @@ -44,6 +45,225 @@ std::vector> EnumValuesInput( return result; } +void EnumProxiedInMapValue(Context& ctx, const EnumDescriptor& desc) { + switch (ctx.opts().kernel) { + case Kernel::kCpp: + for (const auto& t : kMapKeyTypes) { + ctx.Emit( + {{"map_new_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "new")}, + {"map_free_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "free")}, + {"map_clear_thunk", + RawMapThunk(ctx, desc, t.thunk_ident, "clear")}, + {"map_size_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "size")}, + {"map_insert_thunk", + RawMapThunk(ctx, desc, t.thunk_ident, "insert")}, + {"map_get_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "get")}, + {"map_remove_thunk", + RawMapThunk(ctx, desc, t.thunk_ident, "remove")}, + {"map_iter_thunk", RawMapThunk(ctx, desc, t.thunk_ident, "iter")}, + {"map_iter_get_thunk", + RawMapThunk(ctx, desc, t.thunk_ident, "iter_get")}, + {"to_ffi_key_expr", t.rs_to_ffi_key_expr}, + io::Printer::Sub("ffi_key_t", [&] { ctx.Emit(t.rs_ffi_key_t); }) + .WithSuffix(""), + io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); }) + .WithSuffix(""), + io::Printer::Sub("from_ffi_key_expr", + [&] { ctx.Emit(t.rs_from_ffi_key_expr); }) + .WithSuffix("")}, + R"rs( + extern "C" { + fn $map_new_thunk$() -> $pbi$::RawMap; + fn $map_free_thunk$(m: $pbi$::RawMap); + fn $map_clear_thunk$(m: $pbi$::RawMap); + fn $map_size_thunk$(m: $pbi$::RawMap) -> usize; + fn $map_insert_thunk$(m: $pbi$::RawMap, key: $ffi_key_t$, value: $name$) -> bool; + fn $map_get_thunk$(m: $pbi$::RawMap, key: $ffi_key_t$, value: *mut $name$) -> bool; + fn $map_remove_thunk$(m: $pbi$::RawMap, key: $ffi_key_t$, value: *mut $name$) -> bool; + fn $map_iter_thunk$(m: $pbi$::RawMap) -> $pbr$::UntypedMapIterator; + fn $map_iter_get_thunk$(iter: &mut $pbr$::UntypedMapIterator, key: *mut $ffi_key_t$, value: *mut $name$); + } + impl $pb$::ProxiedInMapValue<$key_t$> for $name$ { + fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> { + unsafe { + $pb$::Map::from_inner( + $pbi$::Private, + $pbr$::InnerMapMut::new($pbi$::Private, $map_new_thunk$()) + ) + } + } + + unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) { + unsafe { $map_free_thunk$(map.as_raw($pbi$::Private)); } + } + + fn map_clear(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>) { + unsafe { $map_clear_thunk$(map.as_raw($pbi$::Private)); } + } + + fn map_len(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> usize { + unsafe { $map_size_thunk$(map.as_raw($pbi$::Private)) } + } + + fn map_insert(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>, value: $pb$::View<'_, Self>) -> bool { + unsafe { $map_insert_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value) } + } + + fn map_get<'a>(map: $pb$::View<'a, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> Option<$pb$::View<'a, Self>> { + let key = $to_ffi_key_expr$; + let mut value = $std$::mem::MaybeUninit::uninit(); + let found = unsafe { $map_get_thunk$(map.as_raw($pbi$::Private), key, value.as_mut_ptr()) }; + if !found { + return None; + } + Some(unsafe { value.assume_init() }) + } + + fn map_remove(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> bool { + let mut value = $std$::mem::MaybeUninit::uninit(); + unsafe { $map_remove_thunk$(map.as_raw($pbi$::Private), $to_ffi_key_expr$, value.as_mut_ptr()) } + } + + fn map_iter(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> $pb$::MapIter<'_, $key_t$, Self> { + // SAFETY: + // - The backing map for `map.as_raw` is valid for at least '_. + // - A View that is live for '_ guarantees the backing map is unmodified for '_. + // - The `iter` function produces an iterator that is valid for the key + // and value types, and live for at least '_. + unsafe { + $pb$::MapIter::from_raw( + $pbi$::Private, + $map_iter_thunk$(map.as_raw($pbi$::Private)) + ) + } + } + + fn map_iter_next<'a>(iter: &mut $pb$::MapIter<'a, $key_t$, Self>) -> Option<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> { + // SAFETY: + // - The `MapIter` API forbids the backing map from being mutated for 'a, + // and guarantees that it's the correct key and value types. + // - The thunk is safe to call as long as the iterator isn't at the end. + // - The thunk always writes to key and value fields and does not read. + // - The thunk does not increment the iterator. + unsafe { + iter.as_raw_mut($pbi$::Private).next_unchecked::<$key_t$, Self, _, _>( + $pbi$::Private, + $map_iter_get_thunk$, + |ffi_key| $from_ffi_key_expr$, + $std$::convert::identity, + ) + } + } + } + )rs"); + } + return; + case Kernel::kUpb: + for (const auto& t : kMapKeyTypes) { + ctx.Emit({io::Printer::Sub("key_t", [&] { ctx.Emit(t.rs_key_t); }) + .WithSuffix("")}, + R"rs( + impl $pb$::ProxiedInMapValue<$key_t$> for $name$ { + fn map_new(_private: $pbi$::Private) -> $pb$::Map<$key_t$, Self> { + let arena = $pbr$::Arena::new(); + let raw_arena = arena.raw(); + std::mem::forget(arena); + + unsafe { + $pb$::Map::from_inner( + $pbi$::Private, + $pbr$::InnerMapMut::new( + $pbi$::Private, + $pbr$::upb_Map_New( + raw_arena, + <$key_t$ as $pbr$::UpbTypeConversions>::upb_type(), + $pbr$::UpbCType::Enum), + raw_arena)) + } + } + + unsafe fn map_free(_private: $pbi$::Private, map: &mut $pb$::Map<$key_t$, Self>) { + // SAFETY: + // - `map.raw_arena($pbi$::Private)` is a live `upb_Arena*` + // - This function is only called once for `map` in `Drop`. + unsafe { + $pbr$::upb_Arena_Free(map.inner($pbi$::Private).raw_arena($pbi$::Private)); + } + } + + fn map_clear(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>) { + unsafe { + $pbr$::upb_Map_Clear(map.as_raw($pbi$::Private)); + } + } + + fn map_len(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> usize { + unsafe { + $pbr$::upb_Map_Size(map.as_raw($pbi$::Private)) + } + } + + fn map_insert(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>, value: $pb$::View<'_, Self>) -> bool { + let arena = map.inner($pbi$::Private).raw_arena($pbi$::Private); + unsafe { + $pbr$::upb_Map_InsertAndReturnIfInserted( + map.as_raw($pbi$::Private), + <$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key), + $pbr$::upb_MessageValue { int32_val: value.0 }, + arena + ) + } + } + + fn map_get<'a>(map: $pb$::View<'a, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> Option<$pb$::View<'a, Self>> { + let mut val = $std$::mem::MaybeUninit::uninit(); + let found = unsafe { + $pbr$::upb_Map_Get( + map.as_raw($pbi$::Private), + <$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key), + val.as_mut_ptr()) + }; + if !found { + return None; + } + Some($name$(unsafe { val.assume_init().int32_val })) + } + + fn map_remove(mut map: $pb$::Mut<'_, $pb$::Map<$key_t$, Self>>, key: $pb$::View<'_, $key_t$>) -> bool { + let mut val = $std$::mem::MaybeUninit::uninit(); + unsafe { + $pbr$::upb_Map_Delete( + map.as_raw($pbi$::Private), + <$key_t$ as $pbr$::UpbTypeConversions>::to_message_value(key), + val.as_mut_ptr()) + } + } + fn map_iter(map: $pb$::View<'_, $pb$::Map<$key_t$, Self>>) -> $pb$::MapIter<'_, $key_t$, Self> { + // SAFETY: View> guarantees its RawMap outlives '_. + unsafe { + $pb$::MapIter::from_raw($pbi$::Private, $pbr$::RawMapIter::new($pbi$::Private, map.as_raw($pbi$::Private))) + } + } + + fn map_iter_next<'a>( + iter: &mut $pb$::MapIter<'a, $key_t$, Self> + ) -> Option<($pb$::View<'a, $key_t$>, $pb$::View<'a, Self>)> { + // SAFETY: MapIter<'a, ..> guarantees its RawMapIter outlives 'a. + unsafe { iter.as_raw_mut($pbi$::Private).next_unchecked($pbi$::Private) } + // SAFETY: MapIter returns key and values message values + // with the variants for K and V active. + .map(|(k, v)| unsafe {( + <$key_t$ as $pbr$::UpbTypeConversions>::from_message_value(k), + Self(v.int32_val), + )}) + } + } + )rs"); + } + return; + } +} + } // namespace std::vector EnumValues( @@ -149,6 +369,7 @@ void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) { )rs"); } }}, + {"impl_proxied_in_map", [&] { EnumProxiedInMapValue(ctx, desc); }}, }, R"rs( #[repr(transparent)] @@ -277,9 +498,26 @@ void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc) { unsafe impl $pbi$::Enum for $name$ { const NAME: &'static str = "$name$"; } + + $impl_proxied_in_map$ )rs"); } +void GenerateEnumThunksCc(Context& ctx, const EnumDescriptor& desc) { + ctx.Emit( + { + {"cpp_t", cpp::QualifiedClassName(&desc)}, + {"rs_t", GetUnderscoreDelimitedFullName(ctx, desc)}, + {"abi", "\"C\""}, // Workaround for syntax highlight bug in VSCode. + }, + R"cc( + extern $abi$ { + __PB_RUST_EXPOSE_SCALAR_MAP_METHODS_FOR_VALUE_TYPE( + $cpp_t$, $rs_t$, $cpp_t$, value, cpp_value) + } + )cc"); +} + } // namespace rust } // namespace compiler } // namespace protobuf diff --git a/src/google/protobuf/compiler/rust/enum.h b/src/google/protobuf/compiler/rust/enum.h index 4c04b51cb8..28dcdcbcf9 100644 --- a/src/google/protobuf/compiler/rust/enum.h +++ b/src/google/protobuf/compiler/rust/enum.h @@ -23,8 +23,12 @@ namespace protobuf { namespace compiler { namespace rust { +// Generates code for a particular enum in `.pb.rs`. void GenerateEnumDefinition(Context& ctx, const EnumDescriptor& desc); +// Generates code for a particular enum in `.pb.thunk.cc`. +void GenerateEnumThunksCc(Context& ctx, const EnumDescriptor& desc); + // An enum value with a unique number and any aliases for it. struct RustEnumValue { // The canonical CamelCase name in Rust. diff --git a/src/google/protobuf/compiler/rust/generator.cc b/src/google/protobuf/compiler/rust/generator.cc index 5239bec56c..71a8c1b107 100644 --- a/src/google/protobuf/compiler/rust/generator.cc +++ b/src/google/protobuf/compiler/rust/generator.cc @@ -227,8 +227,19 @@ bool RustGenerator::Generate(const FileDescriptor* file, } for (int i = 0; i < file->enum_type_count(); ++i) { - GenerateEnumDefinition(ctx, *file->enum_type(i)); + auto& enum_ = *file->enum_type(i); + GenerateEnumDefinition(ctx, enum_); ctx.printer().PrintRaw("\n"); + + if (ctx.is_cpp()) { + auto thunks_ctx = ctx.WithPrinter(thunks_printer.get()); + + thunks_ctx.Emit({{"enum", enum_.full_name()}}, R"cc( + // $enum$ + )cc"); + GenerateEnumThunksCc(thunks_ctx, enum_); + thunks_ctx.printer().PrintRaw("\n"); + } } return true; diff --git a/src/google/protobuf/compiler/rust/naming.cc b/src/google/protobuf/compiler/rust/naming.cc index 4b35d042f9..39ff326ff8 100644 --- a/src/google/protobuf/compiler/rust/naming.cc +++ b/src/google/protobuf/compiler/rust/naming.cc @@ -35,14 +35,6 @@ namespace google { namespace protobuf { namespace compiler { namespace rust { -namespace { -std::string GetUnderscoreDelimitedFullName(Context& ctx, - const Descriptor& msg) { - std::string result = msg.full_name(); - absl::StrReplaceAll({{".", "_"}}, &result); - return result; -} -} // namespace std::string GetCrateName(Context& ctx, const FileDescriptor& dep) { return RsSafeName(ctx.generator_context().ImportPathToCrateName(dep.name())); @@ -73,10 +65,16 @@ std::string GetHeaderFile(Context& ctx, const FileDescriptor& file) { std::string RawMapThunk(Context& ctx, const Descriptor& msg, absl::string_view key_t, absl::string_view op) { - return absl::StrCat("__rust_proto_thunk__", key_t, "_", + return absl::StrCat("__rust_proto_thunk__Map_", key_t, "_", GetUnderscoreDelimitedFullName(ctx, *&msg), "_", op); } +std::string RawMapThunk(Context& ctx, const EnumDescriptor& desc, + absl::string_view key_t, absl::string_view op) { + return absl::StrCat("__rust_proto_thunk__Map_", key_t, "_", + GetUnderscoreDelimitedFullName(ctx, *&desc), "_", op); +} + namespace { template @@ -197,6 +195,20 @@ std::string GetFullyQualifiedPath(Context& ctx, const EnumDescriptor& enum_) { return absl::StrCat(GetCrateName(ctx, *enum_.file()), "::", rel_path); } +std::string GetUnderscoreDelimitedFullName(Context& ctx, + const Descriptor& msg) { + std::string result = msg.full_name(); + absl::StrReplaceAll({{".", "_"}}, &result); + return result; +} + +std::string GetUnderscoreDelimitedFullName(Context& ctx, + const EnumDescriptor& enum_) { + std::string result = enum_.full_name(); + absl::StrReplaceAll({{".", "_"}}, &result); + return result; +} + std::string RsTypePath(Context& ctx, const FieldDescriptor& field) { switch (GetRustFieldType(field)) { case RustFieldType::BOOL: diff --git a/src/google/protobuf/compiler/rust/naming.h b/src/google/protobuf/compiler/rust/naming.h index 7a26343641..2928c5dccb 100644 --- a/src/google/protobuf/compiler/rust/naming.h +++ b/src/google/protobuf/compiler/rust/naming.h @@ -35,6 +35,8 @@ std::string ThunkName(Context& ctx, const Descriptor& msg, absl::string_view op); std::string RawMapThunk(Context& ctx, const Descriptor& msg, absl::string_view key_t, absl::string_view op); +std::string RawMapThunk(Context& ctx, const EnumDescriptor& desc, + absl::string_view key_t, absl::string_view op); // Returns the local constant that defines the vtable for mutating `field`. std::string VTableName(const FieldDescriptor& field); @@ -75,6 +77,10 @@ std::string GetCrateRelativeQualifiedPath(Context& ctx, const Descriptor& msg); std::string GetCrateRelativeQualifiedPath(Context& ctx, const EnumDescriptor& enum_); +std::string GetUnderscoreDelimitedFullName(Context& ctx, const Descriptor& msg); +std::string GetUnderscoreDelimitedFullName(Context& ctx, + const EnumDescriptor& enum_); + // TODO: Unify these with other case-conversion functions. // Converts an UpperCamel or lowerCamel string to a snake_case string.