Add support for setting extensions with Ptr<Extension> in Upb C++ protos.

PiperOrigin-RevId: 633646216
pull/16740/head
Protobuf Team Bot 8 months ago committed by Copybara-Service
parent 448e326200
commit 396d661767
  1. 55
      protos/protos.h
  2. 73
      protos_generator/tests/test_generated.cc
  3. 13
      protos_generator/tests/test_model.proto

@ -340,29 +340,28 @@ ABSL_MUST_USE_RESULT bool HasExtension(
return HasExtension(protos::Ptr(message), id);
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
typename = EnableIfMutableProto<T>>
void ClearExtension(
Ptr<T> message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
const ::protos::internal::ExtensionIdentifier<T, Extension>& id) {
static_assert(!std::is_const_v<T>, "");
upb_Message_ClearExtension(internal::GetInternalMsg(message),
id.mini_table_ext());
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>>
template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
void ClearExtension(
T* message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
const ::protos::internal::ExtensionIdentifier<T, Extension>& id) {
ClearExtension(::protos::Ptr(message), id);
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
typename = EnableIfMutableProto<T>>
absl::Status SetExtension(
Ptr<T> message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
const Extension& value) {
static_assert(!std::is_const_v<T>);
auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
@ -371,11 +370,24 @@ absl::Status SetExtension(
internal::GetInternalMsg(&value));
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
typename = EnableIfMutableProto<T>>
absl::Status SetExtension(
Ptr<T> message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
Ptr<Extension> value) {
static_assert(!std::is_const_v<T>);
auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
return ::protos::internal::SetExtension(internal::GetInternalMsg(message),
message_arena, id.mini_table_ext(),
internal::GetInternalMsg(value));
}
template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
typename = EnableIfMutableProto<T>>
absl::Status SetExtension(
Ptr<T> message,
const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
Extension&& value) {
Extension ext = std::move(value);
static_assert(!std::is_const_v<T>);
@ -386,25 +398,28 @@ absl::Status SetExtension(
internal::GetInternalMsg(&ext), extension_arena);
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>>
template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
absl::Status SetExtension(
T* message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
T* message, const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
const Extension& value) {
return ::protos::SetExtension(::protos::Ptr(message), id, value);
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>>
template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
absl::Status SetExtension(
T* message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
T* message, const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
Extension&& value) {
return ::protos::SetExtension(::protos::Ptr(message), id,
std::forward<Extension>(value));
}
template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
absl::Status SetExtension(
T* message, const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
Ptr<Extension> value) {
return ::protos::SetExtension(::protos::Ptr(message), id, value);
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>>
absl::StatusOr<Ptr<const Extension>> GetExtension(

@ -5,6 +5,7 @@
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
#include <cstdint>
#include <iterator>
#include <limits>
#include <memory>
@ -24,10 +25,13 @@
#include "protos_generator/tests/no_package.upb.proto.h"
#include "protos_generator/tests/test_model.upb.proto.h"
#include "upb/mem/arena.h"
#include "upb/mem/arena.hpp"
namespace {
using ::protos_generator::test::protos::ChildModel1;
using ::protos_generator::test::protos::container_ext;
using ::protos_generator::test::protos::ContainerExtension;
using ::protos_generator::test::protos::other_ext;
using ::protos_generator::test::protos::RED;
using ::protos_generator::test::protos::TestEnum;
@ -39,7 +43,6 @@ using ::protos_generator::test::protos::TestModel_Category_VIDEO;
using ::protos_generator::test::protos::theme;
using ::protos_generator::test::protos::ThemeExtension;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
// C++17 port of C++20 `requires`
template <typename... T, typename F>
@ -442,7 +445,7 @@ TEST(CppGeneratedCode, RepeatedScalarIterator) {
EXPECT_EQ(sum, 5 + 16 + 27);
// Access by const reference.
sum = 0;
for (const int& i : *test_model.mutable_value_array()) {
for (const auto& i : *test_model.mutable_value_array()) {
sum += i;
}
EXPECT_EQ(sum, 5 + 16 + 27);
@ -551,7 +554,7 @@ TEST(CppGeneratedCode, RepeatedFieldProxyForMessages) {
}
i = 0;
for (auto child : *test_model.mutable_child_models()) {
for (const auto& child : *test_model.mutable_child_models()) {
if (i++ == 0) {
EXPECT_EQ(child.child_str1(), kTestStr1);
} else {
@ -725,6 +728,70 @@ TEST(CppGeneratedCode, SetExtension) {
EXPECT_EQ(::protos::internal::GetInternalMsg(*ext), prior_message);
}
TEST(CppGeneratedCode, SetExtensionWithPtr) {
::protos::Arena arena_model;
::protos::Ptr<TestModel> model =
::protos::CreateMessage<TestModel>(arena_model);
void* prior_message;
{
// Use a nested scope to make sure the arenas are fused correctly.
::protos::Arena arena;
::protos::Ptr<ThemeExtension> extension1 =
::protos::CreateMessage<ThemeExtension>(arena);
extension1->set_ext_name("Hello World");
prior_message = ::protos::internal::GetInternalMsg(extension1);
EXPECT_EQ(false, ::protos::HasExtension(model, theme));
auto res = ::protos::SetExtension(model, theme, extension1);
EXPECT_EQ(true, res.ok());
}
EXPECT_EQ(true, ::protos::HasExtension(model, theme));
auto ext = ::protos::GetExtension(model, theme);
EXPECT_TRUE(ext.ok());
EXPECT_NE(::protos::internal::GetInternalMsg(*ext), prior_message);
}
#ifndef _MSC_VER
TEST(CppGeneratedCode, SetExtensionShouldNotCompileForWrongType) {
::protos::Arena arena;
::protos::Ptr<TestModel> model = ::protos::CreateMessage<TestModel>(arena);
ThemeExtension extension1;
ContainerExtension extension2;
const auto canSetExtension = [&](auto l) {
return Requires<decltype(model)>(l);
};
EXPECT_TRUE(canSetExtension(
[](auto p) -> decltype(::protos::SetExtension(p, theme, extension1)) {}));
// Wrong extension value type should fail to compile.
EXPECT_TRUE(!canSetExtension(
[](auto p) -> decltype(::protos::SetExtension(p, theme, extension2)) {}));
// Wrong extension id with correct extension type should fail to compile.
EXPECT_TRUE(
!canSetExtension([](auto p) -> decltype(::protos::SetExtension(
p, container_ext, extension1)) {}));
}
#endif
TEST(CppGeneratedCode, SetExtensionWithPtrSameArena) {
::protos::Arena arena;
::protos::Ptr<TestModel> model = ::protos::CreateMessage<TestModel>(arena);
void* prior_message;
{
// Use a nested scope to make sure the arenas are fused correctly.
::protos::Ptr<ThemeExtension> extension1 =
::protos::CreateMessage<ThemeExtension>(arena);
extension1->set_ext_name("Hello World");
prior_message = ::protos::internal::GetInternalMsg(extension1);
EXPECT_EQ(false, ::protos::HasExtension(model, theme));
auto res = ::protos::SetExtension(model, theme, extension1);
EXPECT_EQ(true, res.ok());
}
EXPECT_EQ(true, ::protos::HasExtension(model, theme));
auto ext = ::protos::GetExtension(model, theme);
EXPECT_TRUE(ext.ok());
EXPECT_NE(::protos::internal::GetInternalMsg(*ext), prior_message);
}
TEST(CppGeneratedCode, SetExtensionFusingFailureShouldCopy) {
// Use an initial block to disallow fusing.
char initial_block[1000];

@ -14,6 +14,8 @@ import "protos_generator/tests/child_model.proto";
message TestModelContainer {
repeated TestModel models = 1;
optional ChildModel3 proto_3_child = 2;
extensions 10000 to max
[verification = UNVERIFIED];
}
message TestModel {
@ -138,6 +140,17 @@ extend TestModel {
optional ThemeExtension theme = 12001;
}
message ContainerExtension {
extend TestModelContainer {
optional ContainerExtension container_extension = 12004;
}
optional string ext_container_name = 1;
}
extend TestModelContainer {
optional ContainerExtension container_ext = 12005;
}
message OtherExtension {
optional string ext2_name = 1;
}

Loading…
Cancel
Save