diff --git a/hpb_generator/gen_repeated_fields.cc b/hpb_generator/gen_repeated_fields.cc index bf67fe7838..0a153190ef 100644 --- a/hpb_generator/gen_repeated_fields.cc +++ b/hpb_generator/gen_repeated_fields.cc @@ -43,6 +43,7 @@ void WriteRepeatedFieldUsingAccessors(const protobuf::FieldDescriptor* field, output( R"cc( using $0Access::add_$1; + using $0Access::add_alias_$1; using $0Access::mutable_$1; )cc", class_name, resolved_field_name); @@ -89,6 +90,12 @@ void WriteRepeatedFieldsInMessageHeader(const protobuf::Descriptor* desc, const ::hpb::RepeatedField::CProxy $2() const; ::hpb::Ptr<::hpb::RepeatedField<$4>> mutable_$2(); absl::StatusOr<$0> add_$2(); + /** + * Re-points submsg of repeated field to given target. + * + * REQUIRES: both messages must be in the same arena. + */ + bool add_alias_$2($0 target); $0 mutable_$2(size_t index) const; )cc", MessagePtrConstType(field, /* const */ false), // $0 @@ -149,13 +156,26 @@ void WriteRepeatedMessageAccessor(const protobuf::Descriptor* message, if (!new_msg) { return ::hpb::MessageAllocationError(); } - return hpb::interop::upb::MakeHandle<$4>((upb_Message*)new_msg, $5); + return hpb::interop::upb::MakeHandle<$4>((upb_Message *)new_msg, $5); + } + + bool $0::add_alias_$2($1 target) { + ABSL_CHECK_EQ(arena_, hpb::interop::upb::GetArena(target)); + size_t size = 0; + $3_$2(msg_, &size); + auto elements = $3_resize_$2(msg_, size + 1, arena_); + if (!elements) { + return false; + } + elements[size] = ($9 *)hpb::interop::upb::GetMessage(target); + return true; } )cc", class_name, MessagePtrConstType(field, /* const */ false), resolved_field_name, MessageName(message), MessageBaseType(field, /* maybe_const */ false), arena_expression, - upbc_name); + upbc_name, ClassName(message), field->index(), + upb::generator::CApiMessageType(field->message_type()->full_name())); output( R"cc( $1 $0::mutable_$2(size_t index) const { diff --git a/hpb_generator/tests/set_alias.proto b/hpb_generator/tests/set_alias.proto index b55923fa60..4b53fb1a8a 100644 --- a/hpb_generator/tests/set_alias.proto +++ b/hpb_generator/tests/set_alias.proto @@ -17,3 +17,7 @@ message Parent { optional int32 x = 1; optional Child child = 2; } + +message ParentWithRepeated { + repeated Child children = 1; +} diff --git a/hpb_generator/tests/test_generated.cc b/hpb_generator/tests/test_generated.cc index 4619b887a2..d17904137f 100644 --- a/hpb_generator/tests/test_generated.cc +++ b/hpb_generator/tests/test_generated.cc @@ -41,6 +41,7 @@ using ::hpb_unittest::protos::container_ext; using ::hpb_unittest::protos::ContainerExtension; using ::hpb_unittest::protos::other_ext; using ::hpb_unittest::protos::Parent; +using ::hpb_unittest::protos::ParentWithRepeated; using ::hpb_unittest::protos::RED; using ::hpb_unittest::protos::TestEnum; using ::hpb_unittest::protos::TestModel; @@ -1273,4 +1274,29 @@ TEST(CppGeneratedCode, SetAliasFailsForDifferentArena) { EXPECT_DEATH(parent.set_alias_child(child), "hpb::interop::upb::GetArena"); } +TEST(CppGeneratedCode, SetAliasRepeated) { + hpb::Arena arena; + auto child = hpb::CreateMessage(arena); + child.set_peeps(1611); + auto parent1 = hpb::CreateMessage(arena); + auto parent2 = hpb::CreateMessage(arena); + parent1.add_alias_children(child); + parent2.add_alias_children(child); + + ASSERT_EQ(parent1.children(0)->peeps(), parent2.children(0)->peeps()); + ASSERT_EQ(hpb::interop::upb::GetMessage(parent1.children(0)), + hpb::interop::upb::GetMessage(parent2.children(0))); + auto childPtr = hpb::Ptr(child); + ASSERT_EQ(hpb::interop::upb::GetMessage(childPtr), + hpb::interop::upb::GetMessage(parent1.children(0))); +} + +TEST(CppGeneratedCode, SetAliasRepeatedFailsForDifferentArena) { + hpb::Arena arena; + auto child = hpb::CreateMessage(arena); + hpb::Arena different_arena; + auto parent = hpb::CreateMessage(different_arena); + EXPECT_DEATH(parent.add_alias_children(child), "hpb::interop::upb::GetArena"); +} + } // namespace