Fix SetExtension value semantics.

Move SetExtension common code out of header to protos.cc.

PiperOrigin-RevId: 561829364
pull/13817/head
Protobuf Team Bot 1 year ago committed by Copybara-Service
parent 12d4f418a7
commit cb8da95e91
  1. 2
      upb/protos/BUILD
  2. 36
      upb/protos/protos.cc
  3. 62
      upb/protos/protos.h
  4. 9
      upb/protos_generator/tests/BUILD
  5. 50
      upb/protos_generator/tests/test_generated.cc

@ -79,8 +79,10 @@ cc_library(
"//:message_copy",
"//:message_internal",
"//:message_promote",
"//:message_types",
"//:mini_table",
"//:wire",
"//:wire_types",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",

@ -42,11 +42,13 @@
#include "upb/message/copy.h"
#include "upb/message/internal/extension.h"
#include "upb/message/promote.h"
#include "upb/message/types.h"
#include "upb/mini_table/extension.h"
#include "upb/mini_table/extension_registry.h"
#include "upb/mini_table/message.h"
#include "upb/wire/decode.h"
#include "upb/wire/encode.h"
#include "upb/wire/types.h"
namespace protos {
@ -182,6 +184,40 @@ upb_Message* DeepClone(const upb_Message* source,
return upb_Message_DeepClone(source, mini_table, arena);
}
absl::Status MoveExtension(upb_Message* message, upb_Arena* message_arena,
const upb_MiniTableExtension* ext,
upb_Message* extension, upb_Arena* extension_arena) {
upb_Message_Extension* msg_ext =
_upb_Message_GetOrCreateExtension(message, ext, message_arena);
if (!msg_ext) {
return MessageAllocationError();
}
if (message_arena != extension_arena) {
// Try fuse, if fusing is not allowed or fails, create copy of extension.
if (!upb_Arena_Fuse(message_arena, extension_arena)) {
msg_ext->data.ptr =
DeepClone(extension, msg_ext->ext->sub.submsg, message_arena);
return absl::OkStatus();
}
}
msg_ext->data.ptr = extension;
return absl::OkStatus();
}
absl::Status SetExtension(upb_Message* message, upb_Arena* message_arena,
const upb_MiniTableExtension* ext,
const upb_Message* extension) {
upb_Message_Extension* msg_ext =
_upb_Message_GetOrCreateExtension(message, ext, message_arena);
if (!msg_ext) {
return MessageAllocationError();
}
// Clone extension into target message arena.
msg_ext->data.ptr =
DeepClone(extension, msg_ext->ext->sub.submsg, message_arena);
return absl::OkStatus();
}
} // namespace internal
} // namespace protos

@ -249,6 +249,14 @@ void DeepCopy(upb_Message* target, const upb_Message* source,
upb_Message* DeepClone(const upb_Message* source,
const upb_MiniTable* mini_table, upb_Arena* arena);
absl::Status MoveExtension(upb_Message* message, upb_Arena* message_arena,
const upb_MiniTableExtension* ext,
upb_Message* extension, upb_Arena* extension_arena);
absl::Status SetExtension(upb_Message* message, upb_Arena* message_arena,
const upb_MiniTableExtension* ext,
const upb_Message* extension);
} // namespace internal
template <typename T>
@ -370,29 +378,27 @@ template <typename T, typename Extendee, typename Extension,
absl::Status SetExtension(
Ptr<T> message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
Extension& value) {
const Extension& value) {
static_assert(!std::is_const_v<T>);
auto* message_arena = internal::GetArena(message);
upb_Message_Extension* msg_ext = _upb_Message_GetOrCreateExtension(
internal::GetInternalMsg(message), id.mini_table_ext(), message_arena);
if (!msg_ext) {
return MessageAllocationError();
}
auto* extension_arena = internal::GetArena(&value);
if (message_arena != extension_arena) {
if (!upb_Arena_Fuse(message_arena, extension_arena)) {
// We have to undo the Create part. Otherwise ,we end up with a broken
// extension. We do fuse last because we can undo Create, but we can't
// undo Fuse.
if (msg_ext->data.ptr == nullptr) {
_upb_Message_ClearExtensionField(internal::GetInternalMsg(message),
id.mini_table_ext());
}
return absl::InvalidArgumentError("Unable to fuse arenas.");
}
}
msg_ext->data.ptr = internal::GetInternalMsg(&value);
return absl::OkStatus();
auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
return ::protos::internal::SetExtension(internal::GetInternalMsg(message),
message_arena, id.mini_table_ext(),
internal::GetInternalMsg(&value));
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
absl::Status SetExtension(
Ptr<T> message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
Extension&& value) {
Extension ext = std::move(value);
static_assert(!std::is_const_v<T>);
auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
auto* extension_arena = static_cast<upb_Arena*>(ext.GetInternalArena());
return ::protos::internal::MoveExtension(
internal::GetInternalMsg(message), message_arena, id.mini_table_ext(),
internal::GetInternalMsg(&ext), extension_arena);
}
template <typename T, typename Extendee, typename Extension,
@ -400,10 +406,20 @@ template <typename T, typename Extendee, typename Extension,
absl::Status SetExtension(
T* message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
Extension& value) {
const Extension& value) {
return ::protos::SetExtension(::protos::Ptr(message), id, value);
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>>
absl::Status SetExtension(
T* message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
Extension&& value) {
return ::protos::SetExtension(::protos::Ptr(message), id,
std::forward<Extension>(value));
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>>
absl::StatusOr<Ptr<const Extension>> GetExtension(

@ -27,14 +27,14 @@ load(
"//bazel:build_defs.bzl",
"UPB_DEFAULT_CPPOPTS",
)
load(
"//protos/bazel:upb_cc_proto_library.bzl",
"upb_cc_proto_library",
)
load(
"//bazel:upb_proto_library.bzl",
"upb_proto_library",
)
load(
"//protos/bazel:upb_cc_proto_library.bzl",
"upb_cc_proto_library",
)
load(
"@rules_cc//cc:defs.bzl",
"cc_proto_library",
@ -152,6 +152,7 @@ cc_test(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"//protos",
"//:mem",
"//protos:repeated_field",
],
)

@ -44,6 +44,7 @@
#include "protos_generator/tests/child_model.upb.proto.h"
#include "protos_generator/tests/no_package.upb.proto.h"
#include "protos_generator/tests/test_model.upb.proto.h"
#include "upb/mem/arena.h"
using ::protos_generator::test::protos::ChildModel1;
using ::protos_generator::test::protos::other_ext;
@ -683,20 +684,24 @@ TEST(CppGeneratedCode, ClearExtensionWithEmptyExtensionPtr) {
TEST(CppGeneratedCode, SetExtension) {
TestModel model;
void* prior_message;
{
// Use a nested scope to make sure the arenas are fused correctly.
ThemeExtension extension1;
extension1.set_ext_name("Hello World");
prior_message = ::protos::internal::GetInternalMsg(&extension1);
EXPECT_EQ(false, ::protos::HasExtension(&model, theme));
EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok());
EXPECT_EQ(
true,
::protos::SetExtension(&model, theme, std::move(extension1)).ok());
}
EXPECT_EQ(true, ::protos::HasExtension(&model, theme));
auto ext = ::protos::GetExtension(&model, theme);
EXPECT_TRUE(ext.ok());
EXPECT_EQ((*ext)->ext_name(), "Hello World");
EXPECT_EQ(::protos::internal::GetInternalMsg(*ext), prior_message);
}
TEST(CppGeneratedCode, SetExtensionFailsFusing) {
TEST(CppGeneratedCode, SetExtensionFusingFailureShouldCopy) {
// Use an initial block to disallow fusing.
char initial_block[1000];
protos::Arena arena(initial_block, sizeof(initial_block));
@ -705,12 +710,41 @@ TEST(CppGeneratedCode, SetExtensionFailsFusing) {
ThemeExtension extension1;
extension1.set_ext_name("Hello World");
ASSERT_FALSE(
upb_Arena_Fuse(arena.ptr(), ::protos::internal::GetArena(&extension1)));
EXPECT_FALSE(::protos::HasExtension(model, theme));
auto status = ::protos::SetExtension(model, theme, extension1);
EXPECT_FALSE(status.ok());
EXPECT_THAT(status.message(), HasSubstr("Unable to fuse arenas."));
EXPECT_FALSE(::protos::HasExtension(model, theme));
EXPECT_FALSE(::protos::GetExtension(model, theme).ok());
auto status = ::protos::SetExtension(model, theme, std::move(extension1));
EXPECT_TRUE(status.ok());
EXPECT_TRUE(::protos::HasExtension(model, theme));
EXPECT_TRUE(::protos::GetExtension(model, theme).ok());
}
TEST(CppGeneratedCode, SetExtensionShouldClone) {
TestModel model;
ThemeExtension extension1;
extension1.set_ext_name("Hello World");
EXPECT_EQ(false, ::protos::HasExtension(&model, theme));
EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok());
extension1.set_ext_name("Goodbye");
EXPECT_EQ(true, ::protos::HasExtension(&model, theme));
auto ext = ::protos::GetExtension(&model, theme);
EXPECT_TRUE(ext.ok());
EXPECT_EQ((*ext)->ext_name(), "Hello World");
}
TEST(CppGeneratedCode, SetExtensionShouldCloneConst) {
TestModel model;
ThemeExtension extension1;
extension1.set_ext_name("Hello World");
EXPECT_EQ(false, ::protos::HasExtension(&model, theme));
EXPECT_EQ(
true,
::protos::SetExtension(&model, theme, std::as_const(extension1)).ok());
extension1.set_ext_name("Goodbye");
EXPECT_EQ(true, ::protos::HasExtension(&model, theme));
auto ext = ::protos::GetExtension(&model, theme);
EXPECT_TRUE(ext.ok());
EXPECT_EQ((*ext)->ext_name(), "Hello World");
}
TEST(CppGeneratedCode, SetExtensionOnMutableChild) {

Loading…
Cancel
Save