diff --git a/protos/protos.h b/protos/protos.h index 0bac4b683e..87547f9008 100644 --- a/protos/protos.h +++ b/protos/protos.h @@ -340,29 +340,28 @@ ABSL_MUST_USE_RESULT bool HasExtension( return HasExtension(protos::Ptr(message), id); } -template , typename = EnableIfMutableProto> +template , + typename = EnableIfMutableProto> void ClearExtension( Ptr message, - const ::protos::internal::ExtensionIdentifier& id) { + const ::protos::internal::ExtensionIdentifier& id) { static_assert(!std::is_const_v, ""); upb_Message_ClearExtension(internal::GetInternalMsg(message), id.mini_table_ext()); } -template > +template > void ClearExtension( T* message, - const ::protos::internal::ExtensionIdentifier& id) { + const ::protos::internal::ExtensionIdentifier& id) { ClearExtension(::protos::Ptr(message), id); } -template , typename = EnableIfMutableProto> +template , + typename = EnableIfMutableProto> absl::Status SetExtension( Ptr message, - const ::protos::internal::ExtensionIdentifier& id, + const ::protos::internal::ExtensionIdentifier& id, const Extension& value) { static_assert(!std::is_const_v); auto* message_arena = static_cast(message->GetInternalArena()); @@ -371,11 +370,24 @@ absl::Status SetExtension( internal::GetInternalMsg(&value)); } -template , typename = EnableIfMutableProto> +template , + typename = EnableIfMutableProto> absl::Status SetExtension( Ptr message, - const ::protos::internal::ExtensionIdentifier& id, + const ::protos::internal::ExtensionIdentifier& id, + Ptr value) { + static_assert(!std::is_const_v); + auto* message_arena = static_cast(message->GetInternalArena()); + return ::protos::internal::SetExtension(internal::GetInternalMsg(message), + message_arena, id.mini_table_ext(), + internal::GetInternalMsg(value)); +} + +template , + typename = EnableIfMutableProto> +absl::Status SetExtension( + Ptr message, + const ::protos::internal::ExtensionIdentifier& id, Extension&& value) { Extension ext = std::move(value); static_assert(!std::is_const_v); @@ -386,25 +398,28 @@ absl::Status SetExtension( internal::GetInternalMsg(&ext), extension_arena); } -template > +template > absl::Status SetExtension( - T* message, - const ::protos::internal::ExtensionIdentifier& id, + T* message, const ::protos::internal::ExtensionIdentifier& id, const Extension& value) { return ::protos::SetExtension(::protos::Ptr(message), id, value); } -template > +template > absl::Status SetExtension( - T* message, - const ::protos::internal::ExtensionIdentifier& id, + T* message, const ::protos::internal::ExtensionIdentifier& id, Extension&& value) { return ::protos::SetExtension(::protos::Ptr(message), id, std::forward(value)); } +template > +absl::Status SetExtension( + T* message, const ::protos::internal::ExtensionIdentifier& id, + Ptr value) { + return ::protos::SetExtension(::protos::Ptr(message), id, value); +} + template > absl::StatusOr> GetExtension( diff --git a/protos_generator/tests/test_generated.cc b/protos_generator/tests/test_generated.cc index d745eb01cf..208a1be80a 100644 --- a/protos_generator/tests/test_generated.cc +++ b/protos_generator/tests/test_generated.cc @@ -5,6 +5,7 @@ // license that can be found in the LICENSE file or at // https://developers.google.com/open-source/licenses/bsd +#include #include #include #include @@ -24,10 +25,13 @@ #include "protos_generator/tests/no_package.upb.proto.h" #include "protos_generator/tests/test_model.upb.proto.h" #include "upb/mem/arena.h" +#include "upb/mem/arena.hpp" namespace { using ::protos_generator::test::protos::ChildModel1; +using ::protos_generator::test::protos::container_ext; +using ::protos_generator::test::protos::ContainerExtension; using ::protos_generator::test::protos::other_ext; using ::protos_generator::test::protos::RED; using ::protos_generator::test::protos::TestEnum; @@ -39,7 +43,6 @@ using ::protos_generator::test::protos::TestModel_Category_VIDEO; using ::protos_generator::test::protos::theme; using ::protos_generator::test::protos::ThemeExtension; using ::testing::ElementsAre; -using ::testing::HasSubstr; // C++17 port of C++20 `requires` template @@ -442,7 +445,7 @@ TEST(CppGeneratedCode, RepeatedScalarIterator) { EXPECT_EQ(sum, 5 + 16 + 27); // Access by const reference. sum = 0; - for (const int& i : *test_model.mutable_value_array()) { + for (const auto& i : *test_model.mutable_value_array()) { sum += i; } EXPECT_EQ(sum, 5 + 16 + 27); @@ -551,7 +554,7 @@ TEST(CppGeneratedCode, RepeatedFieldProxyForMessages) { } i = 0; - for (auto child : *test_model.mutable_child_models()) { + for (const auto& child : *test_model.mutable_child_models()) { if (i++ == 0) { EXPECT_EQ(child.child_str1(), kTestStr1); } else { @@ -725,6 +728,70 @@ TEST(CppGeneratedCode, SetExtension) { EXPECT_EQ(::protos::internal::GetInternalMsg(*ext), prior_message); } +TEST(CppGeneratedCode, SetExtensionWithPtr) { + ::protos::Arena arena_model; + ::protos::Ptr model = + ::protos::CreateMessage(arena_model); + void* prior_message; + { + // Use a nested scope to make sure the arenas are fused correctly. + ::protos::Arena arena; + ::protos::Ptr extension1 = + ::protos::CreateMessage(arena); + extension1->set_ext_name("Hello World"); + prior_message = ::protos::internal::GetInternalMsg(extension1); + EXPECT_EQ(false, ::protos::HasExtension(model, theme)); + auto res = ::protos::SetExtension(model, theme, extension1); + EXPECT_EQ(true, res.ok()); + } + EXPECT_EQ(true, ::protos::HasExtension(model, theme)); + auto ext = ::protos::GetExtension(model, theme); + EXPECT_TRUE(ext.ok()); + EXPECT_NE(::protos::internal::GetInternalMsg(*ext), prior_message); +} + +#ifndef _MSC_VER +TEST(CppGeneratedCode, SetExtensionShouldNotCompileForWrongType) { + ::protos::Arena arena; + ::protos::Ptr model = ::protos::CreateMessage(arena); + ThemeExtension extension1; + ContainerExtension extension2; + + const auto canSetExtension = [&](auto l) { + return Requires(l); + }; + EXPECT_TRUE(canSetExtension( + [](auto p) -> decltype(::protos::SetExtension(p, theme, extension1)) {})); + // Wrong extension value type should fail to compile. + EXPECT_TRUE(!canSetExtension( + [](auto p) -> decltype(::protos::SetExtension(p, theme, extension2)) {})); + // Wrong extension id with correct extension type should fail to compile. + EXPECT_TRUE( + !canSetExtension([](auto p) -> decltype(::protos::SetExtension( + p, container_ext, extension1)) {})); +} +#endif + +TEST(CppGeneratedCode, SetExtensionWithPtrSameArena) { + ::protos::Arena arena; + ::protos::Ptr model = ::protos::CreateMessage(arena); + void* prior_message; + { + // Use a nested scope to make sure the arenas are fused correctly. + ::protos::Ptr extension1 = + ::protos::CreateMessage(arena); + extension1->set_ext_name("Hello World"); + prior_message = ::protos::internal::GetInternalMsg(extension1); + EXPECT_EQ(false, ::protos::HasExtension(model, theme)); + auto res = ::protos::SetExtension(model, theme, extension1); + EXPECT_EQ(true, res.ok()); + } + EXPECT_EQ(true, ::protos::HasExtension(model, theme)); + auto ext = ::protos::GetExtension(model, theme); + EXPECT_TRUE(ext.ok()); + EXPECT_NE(::protos::internal::GetInternalMsg(*ext), prior_message); +} + TEST(CppGeneratedCode, SetExtensionFusingFailureShouldCopy) { // Use an initial block to disallow fusing. char initial_block[1000]; diff --git a/protos_generator/tests/test_model.proto b/protos_generator/tests/test_model.proto index 24b4406e95..34c875aed4 100644 --- a/protos_generator/tests/test_model.proto +++ b/protos_generator/tests/test_model.proto @@ -14,6 +14,8 @@ import "protos_generator/tests/child_model.proto"; message TestModelContainer { repeated TestModel models = 1; optional ChildModel3 proto_3_child = 2; + extensions 10000 to max + [verification = UNVERIFIED]; } message TestModel { @@ -138,6 +140,17 @@ extend TestModel { optional ThemeExtension theme = 12001; } +message ContainerExtension { + extend TestModelContainer { + optional ContainerExtension container_extension = 12004; + } + optional string ext_container_name = 1; +} + +extend TestModelContainer { + optional ContainerExtension container_ext = 12005; +} + message OtherExtension { optional string ext2_name = 1; }