Introduce new API DownCastToMessage/DynamicCastToMessage as a replacement for down_cast/dynamic_cast from MessageLite to Message.

The new functions work even when RTTI is not present.

Also, add an assertion in DownCast to make sure we don't use it for message types. Message types have dedicated cast functions that should be used instead.

PiperOrigin-RevId: 633236522
pull/16819/head
Protobuf Team Bot 7 months ago committed by Copybara-Service
parent fad4736785
commit 7f23b700fe
  1. 3
      src/google/protobuf/compiler/cpp/message.cc
  2. 4
      src/google/protobuf/compiler/cpp/service.cc
  3. 4
      src/google/protobuf/compiler/cpp/unittest.inc
  4. 8
      src/google/protobuf/compiler/parser.cc
  5. 2
      src/google/protobuf/descriptor.cc
  6. 2
      src/google/protobuf/extension_set_heavy.cc
  7. 1
      src/google/protobuf/extension_set_unittest.cc
  8. 5
      src/google/protobuf/generated_message_tctable_full.cc
  9. 1
      src/google/protobuf/map_field.cc
  10. 30
      src/google/protobuf/map_test.inc
  11. 15
      src/google/protobuf/message.cc
  12. 41
      src/google/protobuf/message.h
  13. 5
      src/google/protobuf/message_lite.h
  14. 32
      src/google/protobuf/port.h
  15. 10
      src/google/protobuf/reflection_visit_field_info.h
  16. 4
      src/google/protobuf/reflection_visit_fields.h
  17. 13
      src/google/protobuf/repeated_field_reflection_unittest.inc

@ -3747,8 +3747,7 @@ void MessageGenerator::GenerateMergeFrom(io::Printer* p) {
format(
"void $classname$::CheckTypeAndMergeFrom(\n"
" const ::$proto_ns$::MessageLite& from) {\n"
" MergeFrom(*::_pbi::DownCast<const $classname$*>(\n"
" &from));\n"
" MergeFrom(::$proto_ns$::DownCastToGenerated<$classname$>(from));\n"
"}\n");
}
}

@ -256,8 +256,8 @@ void ServiceGenerator::GenerateCallMethodCases(io::Printer* printer) {
R"cc(
case $index$:
$name$(controller,
::$proto_ns$::internal::DownCast<const $input$*>(request),
::$proto_ns$::internal::DownCast<$output$*>(response), done);
::$proto_ns$::DownCastToGenerated<$input$>(request),
::$proto_ns$::DownCastToGenerated<$output$>(response), done);
break;
)cc");
}

@ -1314,13 +1314,13 @@ TEST_F(GENERATED_SERVICE_TEST_NAME, CallMethodTypeFailure) {
EXPECT_DEBUG_DEATH(
mock_service_.CallMethod(foo_, &mock_controller_,
&foo_request_, &bar_response_, done_.get()),
"dynamic_cast");
"DynamicCastToGenerated");
mock_service_.Reset();
EXPECT_DEBUG_DEATH(
mock_service_.CallMethod(foo_, &mock_controller_,
&bar_request_, &foo_response_, done_.get()),
"dynamic_cast");
"DynamicCastToGenerated");
#endif // GTEST_HAS_DEATH_TEST
}

@ -38,6 +38,7 @@
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/io/strtod.h"
#include "google/protobuf/io/tokenizer.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/port.h"
#include "google/protobuf/wire_format.h"
@ -49,8 +50,6 @@ namespace protobuf {
namespace compiler {
namespace {
using ::google::protobuf::internal::DownCast;
using TypeNameMap =
absl::flat_hash_map<absl::string_view, FieldDescriptorProto::Type>;
@ -1566,8 +1565,9 @@ bool Parser::ParseOption(Message* options,
}
UninterpretedOption* uninterpreted_option =
DownCast<UninterpretedOption*>(options->GetReflection()->AddMessage(
options, uninterpreted_option_field));
DownCastToGenerated<UninterpretedOption>(
options->GetReflection()->AddMessage(options,
uninterpreted_option_field));
// Parse dot-separated name.
{

@ -8590,7 +8590,7 @@ bool DescriptorBuilder::OptionInterpreter::InterpretOptionsImpl(
*original_options, original_uninterpreted_options_field);
for (int i = 0; i < num_uninterpreted_options; ++i) {
src_path.push_back(i);
uninterpreted_option_ = DownCast<const UninterpretedOption*>(
uninterpreted_option_ = DownCastToGenerated<UninterpretedOption>(
&original_options->GetReflection()->GetRepeatedMessage(
*original_options, original_uninterpreted_options_field, i));
if (!InterpretSingleOption(options, src_path,

@ -406,7 +406,7 @@ size_t ExtensionSet::Extension::SpaceUsedExcludingSelfLong() const {
if (is_lazy) {
total_size += lazymessage_value->SpaceUsedLong();
} else {
total_size += DownCast<Message*>(message_value)->SpaceUsedLong();
total_size += DownCastToMessage(message_value)->SpaceUsedLong();
}
break;
default:

@ -26,6 +26,7 @@
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/test_util.h"
#include "google/protobuf/test_util2.h"
#include "google/protobuf/text_format.h"

@ -44,7 +44,6 @@
namespace google {
namespace protobuf {
namespace internal {
using ::google::protobuf::internal::DownCast;
const char* TcParser::GenericFallback(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return GenericFallbackImpl<Message, UnknownFieldSet>(
@ -64,7 +63,7 @@ const char* TcParser::ReflectionFallback(PROTOBUF_TC_PARAM_DECL) {
return ptr;
}
auto* full_msg = DownCast<Message*>(msg);
auto* full_msg = DownCastToMessage(msg);
auto* descriptor = full_msg->GetDescriptor();
auto* reflection = full_msg->GetReflection();
int field_number = WireFormatLite::GetTagFieldNumber(tag);
@ -88,7 +87,7 @@ const char* TcParser::ReflectionParseLoop(PROTOBUF_TC_PARAM_DECL) {
(void)table;
(void)hasbits;
// Call into the wire format reflective parse loop.
return WireFormat::_InternalParse(DownCast<Message*>(msg), ptr, ctx);
return WireFormat::_InternalParse(DownCastToMessage(msg), ptr, ctx);
}
const char* TcParser::MessageSetWireFormatParseLoop(

@ -21,7 +21,6 @@
namespace google {
namespace protobuf {
namespace internal {
using ::google::protobuf::internal::DownCast;
VariantKey RealKeyToVariantKey<MapKey>::operator()(const MapKey& value) const {
switch (value.type()) {

@ -1578,7 +1578,7 @@ TEST_F(MapFieldReflectionTest, RegularFields) {
message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message =
DownCast<const ForeignMessage&>(
DownCastToGenerated<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
@ -1615,7 +1615,7 @@ TEST_F(MapFieldReflectionTest, RegularFields) {
message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message =
DownCast<const ForeignMessage&>(
DownCastToGenerated<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
@ -1652,7 +1652,7 @@ TEST_F(MapFieldReflectionTest, RegularFields) {
int32_t key_int32_message =
message_int32_message->GetReflection()->GetInt32(
*message_int32_message, fd_map_int32_foreign_message_key);
ForeignMessage* value_int32_message = DownCast<ForeignMessage*>(
ForeignMessage* value_int32_message = DownCastToGenerated<ForeignMessage>(
message_int32_message->GetReflection()->MutableMessage(
message_int32_message, fd_map_int32_foreign_message_value));
value_int32_message->set_c(Func(key_int32_message, -6));
@ -1808,7 +1808,7 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message =
DownCast<const ForeignMessage&>(
DownCastToGenerated<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
@ -1849,7 +1849,7 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message =
DownCast<const ForeignMessage&>(
DownCastToGenerated<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
@ -1966,8 +1966,8 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
const Message& message = *it;
int32_t key = message.GetReflection()->GetInt32(
message, fd_map_int32_foreign_message_key);
const ForeignMessage& sub_message =
DownCast<const ForeignMessage&>(message.GetReflection()->GetMessage(
const ForeignMessage& sub_message = DownCastToGenerated<ForeignMessage>(
message.GetReflection()->GetMessage(
message, fd_map_int32_foreign_message_value));
result[key].MergeFrom(sub_message);
++index;
@ -2120,14 +2120,14 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
{
const Message& message0a =
mmf_int32_foreign_message.Get(0, entry_int32_foreign_message.get());
const ForeignMessage& sub_message0a =
DownCast<const ForeignMessage&>(message0a.GetReflection()->GetMessage(
const ForeignMessage& sub_message0a = DownCastToGenerated<ForeignMessage>(
message0a.GetReflection()->GetMessage(
message0a, fd_map_int32_foreign_message_value));
int32_t int32_value0a = sub_message0a.c();
const Message& message9a =
mmf_int32_foreign_message.Get(9, entry_int32_foreign_message.get());
const ForeignMessage& sub_message9a =
DownCast<const ForeignMessage&>(message9a.GetReflection()->GetMessage(
const ForeignMessage& sub_message9a = DownCastToGenerated<ForeignMessage>(
message9a.GetReflection()->GetMessage(
message9a, fd_map_int32_foreign_message_value));
int32_t int32_value9a = sub_message9a.c();
@ -2135,14 +2135,14 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
const Message& message0b =
mmf_int32_foreign_message.Get(0, entry_int32_foreign_message.get());
const ForeignMessage& sub_message0b =
DownCast<const ForeignMessage&>(message0b.GetReflection()->GetMessage(
const ForeignMessage& sub_message0b = DownCastToGenerated<ForeignMessage>(
message0b.GetReflection()->GetMessage(
message0b, fd_map_int32_foreign_message_value));
int32_t int32_value0b = sub_message0b.c();
const Message& message9b =
mmf_int32_foreign_message.Get(9, entry_int32_foreign_message.get());
const ForeignMessage& sub_message9b =
DownCast<const ForeignMessage&>(message9b.GetReflection()->GetMessage(
const ForeignMessage& sub_message9b = DownCastToGenerated<ForeignMessage>(
message9b.GetReflection()->GetMessage(
message9b, fd_map_int32_foreign_message_value));
int32_t int32_value9b = sub_message9b.c();

@ -63,12 +63,11 @@ void RegisterFileLevelMetadata(const DescriptorTable* descriptor_table);
} // namespace internal
using internal::DownCast;
using internal::ReflectionOps;
using internal::WireFormat;
void Message::MergeImpl(MessageLite& to, const MessageLite& from) {
ReflectionOps::Merge(DownCast<const Message&>(from), DownCast<Message*>(&to));
ReflectionOps::Merge(DownCastToMessage(from), DownCastToMessage(&to));
}
void Message::MergeFrom(const Message& from) {
@ -82,7 +81,7 @@ void Message::MergeFrom(const Message& from) {
}
void Message::CheckTypeAndMergeFrom(const MessageLite& other) {
MergeFrom(*DownCast<const Message*>(&other));
MergeFrom(DownCastToMessage(other));
}
void Message::CopyFrom(const Message& from) {
@ -113,7 +112,7 @@ void Message::CopyFrom(const Message& from) {
void Message::Clear() { ReflectionOps::Clear(this); }
bool Message::IsInitializedImpl(const MessageLite& msg) {
return ReflectionOps::IsInitialized(DownCast<const Message&>(msg));
return ReflectionOps::IsInitialized(DownCastToMessage(msg));
}
void Message::FindInitializationErrors(std::vector<std::string>* errors) const {
@ -189,20 +188,20 @@ size_t Message::SpaceUsedLong() const {
}
static std::string GetTypeNameImpl(const MessageLite& msg) {
return DownCast<const Message&>(msg).GetDescriptor()->full_name();
return DownCastToMessage(msg).GetDescriptor()->full_name();
}
static std::string InitializationErrorStringImpl(const MessageLite& msg) {
return DownCast<const Message&>(msg).InitializationErrorString();
return DownCastToMessage(msg).InitializationErrorString();
}
const internal::TcParseTableBase* Message::GetTcParseTableImpl(
const MessageLite& msg) {
return DownCast<const Message&>(msg).GetReflection()->GetTcParseTable();
return DownCastToMessage(msg).GetReflection()->GetTcParseTable();
}
size_t Message::SpaceUsedLongImpl(const MessageLite& msg_lite) {
auto& msg = DownCast<const Message&>(msg_lite);
auto& msg = DownCastToMessage(msg_lite);
return msg.GetReflection()->SpaceUsedLong(msg);
}

@ -1449,6 +1449,47 @@ void LinkMessageReflection() {
internal::StrongReferenceToType<T>();
}
// Tries to downcast this message from MessageLite to Message. Returns nullptr
// if this class is not an instance of Message. eg if the message was defined
// with optimized_for=LITE_RUNTIME. This works even if RTTI is disabled.
inline const Message* DynamicCastToMessage(const MessageLite* lite) {
return lite == nullptr || internal::GetClassData(*lite)->is_lite
? nullptr
: static_cast<const Message*>(lite);
}
inline Message* DynamicCastToMessage(MessageLite* lite) {
return const_cast<Message*>(
DynamicCastToMessage(static_cast<const MessageLite*>(lite)));
}
inline const Message& DynamicCastToMessage(const MessageLite& lite) {
auto* res = DynamicCastToMessage(&lite);
ABSL_CHECK(res != nullptr)
<< "Cannot to `Message` type " << lite.GetTypeName();
return *res;
}
inline Message& DynamicCastToMessage(MessageLite& lite) {
return const_cast<Message&>(
DynamicCastToMessage(static_cast<const MessageLite&>(lite)));
}
// A lightweight function for downcasting a MessageLite to Message. It should
// only be used when the caller is certain that the argument is a Message
// object.
inline const Message* DownCastToMessage(const MessageLite* lite) {
ABSL_CHECK(lite == nullptr || DynamicCastToMessage(lite) != nullptr);
return static_cast<const Message*>(lite);
}
inline Message* DownCastToMessage(MessageLite* lite) {
return const_cast<Message*>(
DownCastToMessage(static_cast<const MessageLite*>(lite)));
}
inline const Message& DownCastToMessage(const MessageLite& lite) {
return *DownCastToMessage(&lite);
}
inline Message& DownCastToMessage(MessageLite& lite) {
return *DownCastToMessage(&lite);
}
// =============================================================================
// Implementation details for {Get,Mutable}RawRepeatedPtrField. We provide
// specializations for <std::string>, <StringPieceField> and <Message> and

@ -124,6 +124,7 @@ class PROTOBUF_EXPORT CachedSize {
// For MessageLite to friend.
class TypeId;
auto GetClassData(const MessageLite& msg);
class SwapFieldHelper;
@ -705,6 +706,8 @@ class PROTOBUF_EXPORT MessageLite {
template <typename Type>
friend class internal::GenericTypeHandler;
friend auto internal::GetClassData(const MessageLite& msg);
void LogInitializationErrorMessage() const;
bool MergeFromImpl(io::CodedInputStream* input, ParseFlags parse_flags);
@ -758,6 +761,8 @@ class TypeId {
const MessageLite::ClassData* data_;
};
inline auto GetClassData(const MessageLite& msg) { return msg.GetClassData(); }
template <bool alias>
bool MergeFromImpl(absl::string_view input, MessageLite* msg,
const internal::TcParseTableBase* tc_table,

@ -159,28 +159,34 @@ struct ArenaInitialized {
};
template <typename To, typename From>
inline To DownCast(From* f) {
static_assert(
std::is_base_of<From, typename std::remove_pointer<To>::type>::value,
"illegal DownCast");
void AssertDownCast(From* from) {
static_assert(std::is_base_of<From, To>::value, "illegal DownCast");
#if defined(__cpp_concepts)
// Check that this function is not used to downcast message types.
// For those we should use {Down,Dynamic}CastTo{Message,Generated}.
static_assert(!requires {
std::derived_from<std::remove_pointer_t<To>,
typename std::remove_pointer_t<To>::MessageLite>;
});
#endif
#if PROTOBUF_RTTI
// RTTI: debug mode only!
assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
assert(from == nullptr || dynamic_cast<To*>(from) != nullptr);
#endif
}
template <typename To, typename From>
inline To DownCast(From* f) {
AssertDownCast<std::remove_pointer_t<To>>(f);
return static_cast<To>(f);
}
template <typename ToRef, typename From>
inline ToRef DownCast(From& f) {
using To = typename std::remove_reference<ToRef>::type;
static_assert(std::is_base_of<From, To>::value, "illegal DownCast");
#if PROTOBUF_RTTI
// RTTI: debug mode only!
assert(dynamic_cast<To*>(&f) != nullptr);
#endif
return *static_cast<To*>(&f);
AssertDownCast<std::remove_reference_t<ToRef>>(&f);
return static_cast<ToRef>(f);
}
// Looks up the name of `T` via RTTI, if RTTI is available.

@ -200,10 +200,10 @@ struct DynamicExtensionInfoHelper {
}
static const Message& GetMessage(const Extension& ext) {
return DownCast<const Message&>(*ext.message_value);
return DownCastToMessage(*ext.message_value);
}
static Message& MutableMessage(Extension& ext) {
return DownCast<Message&>(*ext.message_value);
return DownCastToMessage(*ext.message_value);
}
static void ClearMessage(Extension& ext) {
ext.is_cleared = true;
@ -212,18 +212,18 @@ struct DynamicExtensionInfoHelper {
static const Message& GetLazyMessage(const Extension& ext,
const Message& prototype, Arena* arena) {
return DownCast<const Message&>(
return DownCastToMessage(
ext.lazymessage_value->GetMessage(prototype, arena));
}
static const Message& GetLazyMessageIgnoreUnparsed(const Extension& ext,
const Message& prototype,
Arena* arena) {
return DownCast<const Message&>(
return DownCastToMessage(
ext.lazymessage_value->GetMessageIgnoreUnparsed(prototype, arena));
}
static Message& MutableLazyMessage(Extension& ext, const Message& prototype,
Arena* arena) {
return DownCast<Message&>(
return DownCastToMessage(
*ext.lazymessage_value->MutableMessage(prototype, arena));
}
static void ClearLazyMessage(Extension& ext) {

@ -422,7 +422,7 @@ void ReflectionVisit::VisitMessageFields(const Message& message,
FieldDescriptor::CPPTYPE_MESSAGE) {
if constexpr (info.is_repeated) {
for (const auto& it : info.Get()) {
func(DownCast<const Message&>(it));
func(DownCastToMessage(it));
}
} else {
func(info.Get());
@ -452,7 +452,7 @@ void ReflectionVisit::VisitMessageFields(Message& message, CallbackFn&& func) {
FieldDescriptor::CPPTYPE_MESSAGE) {
if constexpr (info.is_repeated) {
for (auto& it : info.Mutable()) {
func(DownCast<Message&>(it));
func(DownCastToMessage(it));
}
} else {
func(info.Mutable());

@ -91,7 +91,7 @@ TEST(REFLECTION_TEST, RegularFields) {
EXPECT_EQ(rf_double.Get(i), Func(i, 2));
EXPECT_EQ(rpf_string.Get(i), StrFunc(i, 5));
EXPECT_EQ(rpf_foreign_message.Get(i).c(), Func(i, 6));
EXPECT_EQ(DownCast<const ForeignMessage*>(&rpf_message.Get(i))->c(),
EXPECT_EQ(DownCastToGenerated<ForeignMessage>(&rpf_message.Get(i))->c(),
Func(i, 6));
// Check gets through mutable objects.
@ -99,7 +99,7 @@ TEST(REFLECTION_TEST, RegularFields) {
EXPECT_EQ(mrf_double->Get(i), Func(i, 2));
EXPECT_EQ(mrpf_string->Get(i), StrFunc(i, 5));
EXPECT_EQ(mrpf_foreign_message->Get(i).c(), Func(i, 6));
EXPECT_EQ(DownCast<const ForeignMessage*>(&mrpf_message->Get(i))->c(),
EXPECT_EQ(DownCastToGenerated<ForeignMessage>(&mrpf_message->Get(i))->c(),
Func(i, 6));
// Check sets through mutable objects.
@ -111,7 +111,8 @@ TEST(REFLECTION_TEST, RegularFields) {
EXPECT_EQ(message.repeated_double(i), Func(i, -2));
EXPECT_EQ(message.repeated_string(i), StrFunc(i, -5));
EXPECT_EQ(message.repeated_foreign_message(i).c(), Func(i, -6));
DownCast<ForeignMessage*>(mrpf_message->Mutable(i))->set_c(Func(i, 7));
DownCastToGenerated<ForeignMessage>(mrpf_message->Mutable(i))
->set_c(Func(i, 7));
EXPECT_EQ(message.repeated_foreign_message(i).c(), Func(i, 7));
}
@ -271,7 +272,8 @@ TEST(REFLECTION_TEST, RepeatedFieldRefForRegularFields) {
ForeignMessage scratch_space;
EXPECT_EQ(rf_foreign_message.Get(i, &scratch_space).c(), Func(i, 6));
EXPECT_EQ(
DownCast<const ForeignMessage&>(rf_message.Get(i, &scratch_space)).c(),
DownCastToGenerated<ForeignMessage>(rf_message.Get(i, &scratch_space))
.c(),
Func(i, 6));
// Check gets through mutable objects.
@ -280,7 +282,8 @@ TEST(REFLECTION_TEST, RepeatedFieldRefForRegularFields) {
EXPECT_EQ(mrf_string.Get(i), StrFunc(i, 5));
EXPECT_EQ(mrf_foreign_message.Get(i, &scratch_space).c(), Func(i, 6));
EXPECT_EQ(
DownCast<const ForeignMessage&>(mrf_message.Get(i, &scratch_space)).c(),
DownCastToGenerated<ForeignMessage>(mrf_message.Get(i, &scratch_space))
.c(),
Func(i, 6));
// Check sets through mutable objects.

Loading…
Cancel
Save