From 09ac8ca24887692e83e88e578a93f9289fc501f0 Mon Sep 17 00:00:00 2001 From: Hong Shin Date: Thu, 10 Oct 2024 08:04:48 -0700 Subject: [PATCH] hpb: Introduce set_alias on maps for (k,v) where v is a message PiperOrigin-RevId: 684446420 --- hpb_generator/gen_accessors.cc | 33 +++++++++++++++++++++++++++ hpb_generator/tests/set_alias.proto | 4 ++++ hpb_generator/tests/test_generated.cc | 20 ++++++++++++++++ 3 files changed, 57 insertions(+) diff --git a/hpb_generator/gen_accessors.cc b/hpb_generator/gen_accessors.cc index ef3f327a62..9a25a2b6ac 100644 --- a/hpb_generator/gen_accessors.cc +++ b/hpb_generator/gen_accessors.cc @@ -181,6 +181,8 @@ void WriteMapFieldAccessors(const protobuf::Descriptor* desc, R"cc( bool set_$0($1 key, $3 value); bool set_$0($1 key, $4 value); + bool set_alias_$0($1 key, $3 value); + bool set_alias_$0($1 key, $4 value); absl::StatusOr<$3> get_$0($1 key); )cc", resolved_field_name, CppConstType(key), CppConstType(val), @@ -338,6 +340,28 @@ void WriteMapAccessorDefinitions(const protobuf::Descriptor* message, converted_key_name, upbc_name, ::upb::generator::MiniTableMessageVarName( val->message_type()->full_name())); + output( + R"cc( + bool $0::set_alias_$1($2 key, $3 value) { + $6return $4_$8_set( + msg_, $7, ($5*)hpb::interop::upb::GetMessage(value), arena_); + } + )cc", + class_name, resolved_field_name, CppConstType(key), + MessagePtrConstType(val, /* is_const */ true), MessageName(message), + MessageName(val->message_type()), optional_conversion_code, + converted_key_name, upbc_name); + output( + R"cc( + bool $0::set_alias_$1($2 key, $3 value) { + $6return $4_$8_set( + msg_, $7, ($5*)hpb::interop::upb::GetMessage(value), arena_); + } + )cc", + class_name, resolved_field_name, CppConstType(key), + MessagePtrConstType(val, /* is_const */ false), MessageName(message), + MessageName(val->message_type()), optional_conversion_code, + converted_key_name, upbc_name); output( R"cc( absl::StatusOr<$3> $0::get_$1($2 key) { @@ -462,6 +486,15 @@ void WriteUsingAccessorsInHeader(const protobuf::Descriptor* desc, using $0Access::set_$1; )cc", class_name, resolved_field_name); + // only emit set_alias for maps when value is a message + if (field->message_type()->FindFieldByNumber(2)->cpp_type() == + protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + output( + R"cc( + using $0Access::set_alias_$1; + )cc", + class_name, resolved_field_name); + } } } else if (desc->options().map_entry()) { // TODO Implement map entry diff --git a/hpb_generator/tests/set_alias.proto b/hpb_generator/tests/set_alias.proto index 4b53fb1a8a..a096e6d2b1 100644 --- a/hpb_generator/tests/set_alias.proto +++ b/hpb_generator/tests/set_alias.proto @@ -21,3 +21,7 @@ message Parent { message ParentWithRepeated { repeated Child children = 1; } + +message ParentWithMap { + map child_map = 1; +} diff --git a/hpb_generator/tests/test_generated.cc b/hpb_generator/tests/test_generated.cc index d17904137f..a7bb538dc1 100644 --- a/hpb_generator/tests/test_generated.cc +++ b/hpb_generator/tests/test_generated.cc @@ -41,6 +41,7 @@ using ::hpb_unittest::protos::container_ext; using ::hpb_unittest::protos::ContainerExtension; using ::hpb_unittest::protos::other_ext; using ::hpb_unittest::protos::Parent; +using ::hpb_unittest::protos::ParentWithMap; using ::hpb_unittest::protos::ParentWithRepeated; using ::hpb_unittest::protos::RED; using ::hpb_unittest::protos::TestEnum; @@ -1299,4 +1300,23 @@ TEST(CppGeneratedCode, SetAliasRepeatedFailsForDifferentArena) { EXPECT_DEATH(parent.add_alias_children(child), "hpb::interop::upb::GetArena"); } +TEST(CppGeneratedCode, SetAliasMap) { + hpb::Arena arena; + auto parent1 = hpb::CreateMessage(arena); + auto parent2 = hpb::CreateMessage(arena); + + auto child = hpb::CreateMessage(arena); + + constexpr int key = 1; + parent1.set_alias_child_map(key, child); + parent2.set_alias_child_map(key, child); + auto c1 = parent1.get_child_map(key); + auto c2 = parent2.get_child_map(key); + + EXPECT_TRUE(c1.ok()); + EXPECT_TRUE(c2.ok()); + ASSERT_EQ(hpb::interop::upb::GetMessage(c1.value()), + hpb::interop::upb::GetMessage(c2.value())); +} + } // namespace