Introduce new API DownCastToMessage/DynamicCastToMessage as a replacement for down_cast/dynamic_cast from MessageLite to Message.

The new functions work even when RTTI is not present.

Also, add an assertion in DownCast to make sure we don't use it for message types. Message types have dedicated cast functions that should be used instead.

PiperOrigin-RevId: 633236522
pull/16819/head
Protobuf Team Bot 9 months ago committed by Copybara-Service
parent fad4736785
commit 7f23b700fe
  1. 3
      src/google/protobuf/compiler/cpp/message.cc
  2. 4
      src/google/protobuf/compiler/cpp/service.cc
  3. 4
      src/google/protobuf/compiler/cpp/unittest.inc
  4. 8
      src/google/protobuf/compiler/parser.cc
  5. 2
      src/google/protobuf/descriptor.cc
  6. 2
      src/google/protobuf/extension_set_heavy.cc
  7. 1
      src/google/protobuf/extension_set_unittest.cc
  8. 5
      src/google/protobuf/generated_message_tctable_full.cc
  9. 1
      src/google/protobuf/map_field.cc
  10. 30
      src/google/protobuf/map_test.inc
  11. 15
      src/google/protobuf/message.cc
  12. 41
      src/google/protobuf/message.h
  13. 5
      src/google/protobuf/message_lite.h
  14. 32
      src/google/protobuf/port.h
  15. 10
      src/google/protobuf/reflection_visit_field_info.h
  16. 4
      src/google/protobuf/reflection_visit_fields.h
  17. 13
      src/google/protobuf/repeated_field_reflection_unittest.inc

@ -3747,8 +3747,7 @@ void MessageGenerator::GenerateMergeFrom(io::Printer* p) {
format( format(
"void $classname$::CheckTypeAndMergeFrom(\n" "void $classname$::CheckTypeAndMergeFrom(\n"
" const ::$proto_ns$::MessageLite& from) {\n" " const ::$proto_ns$::MessageLite& from) {\n"
" MergeFrom(*::_pbi::DownCast<const $classname$*>(\n" " MergeFrom(::$proto_ns$::DownCastToGenerated<$classname$>(from));\n"
" &from));\n"
"}\n"); "}\n");
} }
} }

@ -256,8 +256,8 @@ void ServiceGenerator::GenerateCallMethodCases(io::Printer* printer) {
R"cc( R"cc(
case $index$: case $index$:
$name$(controller, $name$(controller,
::$proto_ns$::internal::DownCast<const $input$*>(request), ::$proto_ns$::DownCastToGenerated<$input$>(request),
::$proto_ns$::internal::DownCast<$output$*>(response), done); ::$proto_ns$::DownCastToGenerated<$output$>(response), done);
break; break;
)cc"); )cc");
} }

@ -1314,13 +1314,13 @@ TEST_F(GENERATED_SERVICE_TEST_NAME, CallMethodTypeFailure) {
EXPECT_DEBUG_DEATH( EXPECT_DEBUG_DEATH(
mock_service_.CallMethod(foo_, &mock_controller_, mock_service_.CallMethod(foo_, &mock_controller_,
&foo_request_, &bar_response_, done_.get()), &foo_request_, &bar_response_, done_.get()),
"dynamic_cast"); "DynamicCastToGenerated");
mock_service_.Reset(); mock_service_.Reset();
EXPECT_DEBUG_DEATH( EXPECT_DEBUG_DEATH(
mock_service_.CallMethod(foo_, &mock_controller_, mock_service_.CallMethod(foo_, &mock_controller_,
&bar_request_, &foo_response_, done_.get()), &bar_request_, &foo_response_, done_.get()),
"dynamic_cast"); "DynamicCastToGenerated");
#endif // GTEST_HAS_DEATH_TEST #endif // GTEST_HAS_DEATH_TEST
} }

@ -38,6 +38,7 @@
#include "google/protobuf/descriptor.pb.h" #include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/io/strtod.h" #include "google/protobuf/io/strtod.h"
#include "google/protobuf/io/tokenizer.h" #include "google/protobuf/io/tokenizer.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/port.h" #include "google/protobuf/port.h"
#include "google/protobuf/wire_format.h" #include "google/protobuf/wire_format.h"
@ -49,8 +50,6 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace { namespace {
using ::google::protobuf::internal::DownCast;
using TypeNameMap = using TypeNameMap =
absl::flat_hash_map<absl::string_view, FieldDescriptorProto::Type>; absl::flat_hash_map<absl::string_view, FieldDescriptorProto::Type>;
@ -1566,8 +1565,9 @@ bool Parser::ParseOption(Message* options,
} }
UninterpretedOption* uninterpreted_option = UninterpretedOption* uninterpreted_option =
DownCast<UninterpretedOption*>(options->GetReflection()->AddMessage( DownCastToGenerated<UninterpretedOption>(
options, uninterpreted_option_field)); options->GetReflection()->AddMessage(options,
uninterpreted_option_field));
// Parse dot-separated name. // Parse dot-separated name.
{ {

@ -8590,7 +8590,7 @@ bool DescriptorBuilder::OptionInterpreter::InterpretOptionsImpl(
*original_options, original_uninterpreted_options_field); *original_options, original_uninterpreted_options_field);
for (int i = 0; i < num_uninterpreted_options; ++i) { for (int i = 0; i < num_uninterpreted_options; ++i) {
src_path.push_back(i); src_path.push_back(i);
uninterpreted_option_ = DownCast<const UninterpretedOption*>( uninterpreted_option_ = DownCastToGenerated<UninterpretedOption>(
&original_options->GetReflection()->GetRepeatedMessage( &original_options->GetReflection()->GetRepeatedMessage(
*original_options, original_uninterpreted_options_field, i)); *original_options, original_uninterpreted_options_field, i));
if (!InterpretSingleOption(options, src_path, if (!InterpretSingleOption(options, src_path,

@ -406,7 +406,7 @@ size_t ExtensionSet::Extension::SpaceUsedExcludingSelfLong() const {
if (is_lazy) { if (is_lazy) {
total_size += lazymessage_value->SpaceUsedLong(); total_size += lazymessage_value->SpaceUsedLong();
} else { } else {
total_size += DownCast<Message*>(message_value)->SpaceUsedLong(); total_size += DownCastToMessage(message_value)->SpaceUsedLong();
} }
break; break;
default: default:

@ -26,6 +26,7 @@
#include "google/protobuf/dynamic_message.h" #include "google/protobuf/dynamic_message.h"
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/test_util.h" #include "google/protobuf/test_util.h"
#include "google/protobuf/test_util2.h" #include "google/protobuf/test_util2.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"

@ -44,7 +44,6 @@
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace internal { namespace internal {
using ::google::protobuf::internal::DownCast;
const char* TcParser::GenericFallback(PROTOBUF_TC_PARAM_DECL) { const char* TcParser::GenericFallback(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return GenericFallbackImpl<Message, UnknownFieldSet>( PROTOBUF_MUSTTAIL return GenericFallbackImpl<Message, UnknownFieldSet>(
@ -64,7 +63,7 @@ const char* TcParser::ReflectionFallback(PROTOBUF_TC_PARAM_DECL) {
return ptr; return ptr;
} }
auto* full_msg = DownCast<Message*>(msg); auto* full_msg = DownCastToMessage(msg);
auto* descriptor = full_msg->GetDescriptor(); auto* descriptor = full_msg->GetDescriptor();
auto* reflection = full_msg->GetReflection(); auto* reflection = full_msg->GetReflection();
int field_number = WireFormatLite::GetTagFieldNumber(tag); int field_number = WireFormatLite::GetTagFieldNumber(tag);
@ -88,7 +87,7 @@ const char* TcParser::ReflectionParseLoop(PROTOBUF_TC_PARAM_DECL) {
(void)table; (void)table;
(void)hasbits; (void)hasbits;
// Call into the wire format reflective parse loop. // Call into the wire format reflective parse loop.
return WireFormat::_InternalParse(DownCast<Message*>(msg), ptr, ctx); return WireFormat::_InternalParse(DownCastToMessage(msg), ptr, ctx);
} }
const char* TcParser::MessageSetWireFormatParseLoop( const char* TcParser::MessageSetWireFormatParseLoop(

@ -21,7 +21,6 @@
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace internal { namespace internal {
using ::google::protobuf::internal::DownCast;
VariantKey RealKeyToVariantKey<MapKey>::operator()(const MapKey& value) const { VariantKey RealKeyToVariantKey<MapKey>::operator()(const MapKey& value) const {
switch (value.type()) { switch (value.type()) {

@ -1578,7 +1578,7 @@ TEST_F(MapFieldReflectionTest, RegularFields) {
message_int32_message.GetReflection()->GetInt32( message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key); message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message = const ForeignMessage& value_int32_message =
DownCast<const ForeignMessage&>( DownCastToGenerated<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage( message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value)); message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6)); EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
@ -1615,7 +1615,7 @@ TEST_F(MapFieldReflectionTest, RegularFields) {
message_int32_message.GetReflection()->GetInt32( message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key); message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message = const ForeignMessage& value_int32_message =
DownCast<const ForeignMessage&>( DownCastToGenerated<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage( message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value)); message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6)); EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
@ -1652,7 +1652,7 @@ TEST_F(MapFieldReflectionTest, RegularFields) {
int32_t key_int32_message = int32_t key_int32_message =
message_int32_message->GetReflection()->GetInt32( message_int32_message->GetReflection()->GetInt32(
*message_int32_message, fd_map_int32_foreign_message_key); *message_int32_message, fd_map_int32_foreign_message_key);
ForeignMessage* value_int32_message = DownCast<ForeignMessage*>( ForeignMessage* value_int32_message = DownCastToGenerated<ForeignMessage>(
message_int32_message->GetReflection()->MutableMessage( message_int32_message->GetReflection()->MutableMessage(
message_int32_message, fd_map_int32_foreign_message_value)); message_int32_message, fd_map_int32_foreign_message_value));
value_int32_message->set_c(Func(key_int32_message, -6)); value_int32_message->set_c(Func(key_int32_message, -6));
@ -1808,7 +1808,7 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
message_int32_message.GetReflection()->GetInt32( message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key); message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message = const ForeignMessage& value_int32_message =
DownCast<const ForeignMessage&>( DownCastToGenerated<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage( message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value)); message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6)); EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
@ -1849,7 +1849,7 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
message_int32_message.GetReflection()->GetInt32( message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key); message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message = const ForeignMessage& value_int32_message =
DownCast<const ForeignMessage&>( DownCastToGenerated<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage( message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value)); message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6)); EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
@ -1966,8 +1966,8 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
const Message& message = *it; const Message& message = *it;
int32_t key = message.GetReflection()->GetInt32( int32_t key = message.GetReflection()->GetInt32(
message, fd_map_int32_foreign_message_key); message, fd_map_int32_foreign_message_key);
const ForeignMessage& sub_message = const ForeignMessage& sub_message = DownCastToGenerated<ForeignMessage>(
DownCast<const ForeignMessage&>(message.GetReflection()->GetMessage( message.GetReflection()->GetMessage(
message, fd_map_int32_foreign_message_value)); message, fd_map_int32_foreign_message_value));
result[key].MergeFrom(sub_message); result[key].MergeFrom(sub_message);
++index; ++index;
@ -2120,14 +2120,14 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
{ {
const Message& message0a = const Message& message0a =
mmf_int32_foreign_message.Get(0, entry_int32_foreign_message.get()); mmf_int32_foreign_message.Get(0, entry_int32_foreign_message.get());
const ForeignMessage& sub_message0a = const ForeignMessage& sub_message0a = DownCastToGenerated<ForeignMessage>(
DownCast<const ForeignMessage&>(message0a.GetReflection()->GetMessage( message0a.GetReflection()->GetMessage(
message0a, fd_map_int32_foreign_message_value)); message0a, fd_map_int32_foreign_message_value));
int32_t int32_value0a = sub_message0a.c(); int32_t int32_value0a = sub_message0a.c();
const Message& message9a = const Message& message9a =
mmf_int32_foreign_message.Get(9, entry_int32_foreign_message.get()); mmf_int32_foreign_message.Get(9, entry_int32_foreign_message.get());
const ForeignMessage& sub_message9a = const ForeignMessage& sub_message9a = DownCastToGenerated<ForeignMessage>(
DownCast<const ForeignMessage&>(message9a.GetReflection()->GetMessage( message9a.GetReflection()->GetMessage(
message9a, fd_map_int32_foreign_message_value)); message9a, fd_map_int32_foreign_message_value));
int32_t int32_value9a = sub_message9a.c(); int32_t int32_value9a = sub_message9a.c();
@ -2135,14 +2135,14 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
const Message& message0b = const Message& message0b =
mmf_int32_foreign_message.Get(0, entry_int32_foreign_message.get()); mmf_int32_foreign_message.Get(0, entry_int32_foreign_message.get());
const ForeignMessage& sub_message0b = const ForeignMessage& sub_message0b = DownCastToGenerated<ForeignMessage>(
DownCast<const ForeignMessage&>(message0b.GetReflection()->GetMessage( message0b.GetReflection()->GetMessage(
message0b, fd_map_int32_foreign_message_value)); message0b, fd_map_int32_foreign_message_value));
int32_t int32_value0b = sub_message0b.c(); int32_t int32_value0b = sub_message0b.c();
const Message& message9b = const Message& message9b =
mmf_int32_foreign_message.Get(9, entry_int32_foreign_message.get()); mmf_int32_foreign_message.Get(9, entry_int32_foreign_message.get());
const ForeignMessage& sub_message9b = const ForeignMessage& sub_message9b = DownCastToGenerated<ForeignMessage>(
DownCast<const ForeignMessage&>(message9b.GetReflection()->GetMessage( message9b.GetReflection()->GetMessage(
message9b, fd_map_int32_foreign_message_value)); message9b, fd_map_int32_foreign_message_value));
int32_t int32_value9b = sub_message9b.c(); int32_t int32_value9b = sub_message9b.c();

@ -63,12 +63,11 @@ void RegisterFileLevelMetadata(const DescriptorTable* descriptor_table);
} // namespace internal } // namespace internal
using internal::DownCast;
using internal::ReflectionOps; using internal::ReflectionOps;
using internal::WireFormat; using internal::WireFormat;
void Message::MergeImpl(MessageLite& to, const MessageLite& from) { void Message::MergeImpl(MessageLite& to, const MessageLite& from) {
ReflectionOps::Merge(DownCast<const Message&>(from), DownCast<Message*>(&to)); ReflectionOps::Merge(DownCastToMessage(from), DownCastToMessage(&to));
} }
void Message::MergeFrom(const Message& from) { void Message::MergeFrom(const Message& from) {
@ -82,7 +81,7 @@ void Message::MergeFrom(const Message& from) {
} }
void Message::CheckTypeAndMergeFrom(const MessageLite& other) { void Message::CheckTypeAndMergeFrom(const MessageLite& other) {
MergeFrom(*DownCast<const Message*>(&other)); MergeFrom(DownCastToMessage(other));
} }
void Message::CopyFrom(const Message& from) { void Message::CopyFrom(const Message& from) {
@ -113,7 +112,7 @@ void Message::CopyFrom(const Message& from) {
void Message::Clear() { ReflectionOps::Clear(this); } void Message::Clear() { ReflectionOps::Clear(this); }
bool Message::IsInitializedImpl(const MessageLite& msg) { bool Message::IsInitializedImpl(const MessageLite& msg) {
return ReflectionOps::IsInitialized(DownCast<const Message&>(msg)); return ReflectionOps::IsInitialized(DownCastToMessage(msg));
} }
void Message::FindInitializationErrors(std::vector<std::string>* errors) const { void Message::FindInitializationErrors(std::vector<std::string>* errors) const {
@ -189,20 +188,20 @@ size_t Message::SpaceUsedLong() const {
} }
static std::string GetTypeNameImpl(const MessageLite& msg) { static std::string GetTypeNameImpl(const MessageLite& msg) {
return DownCast<const Message&>(msg).GetDescriptor()->full_name(); return DownCastToMessage(msg).GetDescriptor()->full_name();
} }
static std::string InitializationErrorStringImpl(const MessageLite& msg) { static std::string InitializationErrorStringImpl(const MessageLite& msg) {
return DownCast<const Message&>(msg).InitializationErrorString(); return DownCastToMessage(msg).InitializationErrorString();
} }
const internal::TcParseTableBase* Message::GetTcParseTableImpl( const internal::TcParseTableBase* Message::GetTcParseTableImpl(
const MessageLite& msg) { const MessageLite& msg) {
return DownCast<const Message&>(msg).GetReflection()->GetTcParseTable(); return DownCastToMessage(msg).GetReflection()->GetTcParseTable();
} }
size_t Message::SpaceUsedLongImpl(const MessageLite& msg_lite) { size_t Message::SpaceUsedLongImpl(const MessageLite& msg_lite) {
auto& msg = DownCast<const Message&>(msg_lite); auto& msg = DownCastToMessage(msg_lite);
return msg.GetReflection()->SpaceUsedLong(msg); return msg.GetReflection()->SpaceUsedLong(msg);
} }

@ -1449,6 +1449,47 @@ void LinkMessageReflection() {
internal::StrongReferenceToType<T>(); internal::StrongReferenceToType<T>();
} }
// Tries to downcast this message from MessageLite to Message. Returns nullptr
// if this class is not an instance of Message. eg if the message was defined
// with optimized_for=LITE_RUNTIME. This works even if RTTI is disabled.
inline const Message* DynamicCastToMessage(const MessageLite* lite) {
return lite == nullptr || internal::GetClassData(*lite)->is_lite
? nullptr
: static_cast<const Message*>(lite);
}
inline Message* DynamicCastToMessage(MessageLite* lite) {
return const_cast<Message*>(
DynamicCastToMessage(static_cast<const MessageLite*>(lite)));
}
inline const Message& DynamicCastToMessage(const MessageLite& lite) {
auto* res = DynamicCastToMessage(&lite);
ABSL_CHECK(res != nullptr)
<< "Cannot to `Message` type " << lite.GetTypeName();
return *res;
}
inline Message& DynamicCastToMessage(MessageLite& lite) {
return const_cast<Message&>(
DynamicCastToMessage(static_cast<const MessageLite&>(lite)));
}
// A lightweight function for downcasting a MessageLite to Message. It should
// only be used when the caller is certain that the argument is a Message
// object.
inline const Message* DownCastToMessage(const MessageLite* lite) {
ABSL_CHECK(lite == nullptr || DynamicCastToMessage(lite) != nullptr);
return static_cast<const Message*>(lite);
}
inline Message* DownCastToMessage(MessageLite* lite) {
return const_cast<Message*>(
DownCastToMessage(static_cast<const MessageLite*>(lite)));
}
inline const Message& DownCastToMessage(const MessageLite& lite) {
return *DownCastToMessage(&lite);
}
inline Message& DownCastToMessage(MessageLite& lite) {
return *DownCastToMessage(&lite);
}
// ============================================================================= // =============================================================================
// Implementation details for {Get,Mutable}RawRepeatedPtrField. We provide // Implementation details for {Get,Mutable}RawRepeatedPtrField. We provide
// specializations for <std::string>, <StringPieceField> and <Message> and // specializations for <std::string>, <StringPieceField> and <Message> and

@ -124,6 +124,7 @@ class PROTOBUF_EXPORT CachedSize {
// For MessageLite to friend. // For MessageLite to friend.
class TypeId; class TypeId;
auto GetClassData(const MessageLite& msg);
class SwapFieldHelper; class SwapFieldHelper;
@ -705,6 +706,8 @@ class PROTOBUF_EXPORT MessageLite {
template <typename Type> template <typename Type>
friend class internal::GenericTypeHandler; friend class internal::GenericTypeHandler;
friend auto internal::GetClassData(const MessageLite& msg);
void LogInitializationErrorMessage() const; void LogInitializationErrorMessage() const;
bool MergeFromImpl(io::CodedInputStream* input, ParseFlags parse_flags); bool MergeFromImpl(io::CodedInputStream* input, ParseFlags parse_flags);
@ -758,6 +761,8 @@ class TypeId {
const MessageLite::ClassData* data_; const MessageLite::ClassData* data_;
}; };
inline auto GetClassData(const MessageLite& msg) { return msg.GetClassData(); }
template <bool alias> template <bool alias>
bool MergeFromImpl(absl::string_view input, MessageLite* msg, bool MergeFromImpl(absl::string_view input, MessageLite* msg,
const internal::TcParseTableBase* tc_table, const internal::TcParseTableBase* tc_table,

@ -159,28 +159,34 @@ struct ArenaInitialized {
}; };
template <typename To, typename From> template <typename To, typename From>
inline To DownCast(From* f) { void AssertDownCast(From* from) {
static_assert( static_assert(std::is_base_of<From, To>::value, "illegal DownCast");
std::is_base_of<From, typename std::remove_pointer<To>::type>::value,
"illegal DownCast"); #if defined(__cpp_concepts)
// Check that this function is not used to downcast message types.
// For those we should use {Down,Dynamic}CastTo{Message,Generated}.
static_assert(!requires {
std::derived_from<std::remove_pointer_t<To>,
typename std::remove_pointer_t<To>::MessageLite>;
});
#endif
#if PROTOBUF_RTTI #if PROTOBUF_RTTI
// RTTI: debug mode only! // RTTI: debug mode only!
assert(f == nullptr || dynamic_cast<To>(f) != nullptr); assert(from == nullptr || dynamic_cast<To*>(from) != nullptr);
#endif #endif
}
template <typename To, typename From>
inline To DownCast(From* f) {
AssertDownCast<std::remove_pointer_t<To>>(f);
return static_cast<To>(f); return static_cast<To>(f);
} }
template <typename ToRef, typename From> template <typename ToRef, typename From>
inline ToRef DownCast(From& f) { inline ToRef DownCast(From& f) {
using To = typename std::remove_reference<ToRef>::type; AssertDownCast<std::remove_reference_t<ToRef>>(&f);
static_assert(std::is_base_of<From, To>::value, "illegal DownCast"); return static_cast<ToRef>(f);
#if PROTOBUF_RTTI
// RTTI: debug mode only!
assert(dynamic_cast<To*>(&f) != nullptr);
#endif
return *static_cast<To*>(&f);
} }
// Looks up the name of `T` via RTTI, if RTTI is available. // Looks up the name of `T` via RTTI, if RTTI is available.

@ -200,10 +200,10 @@ struct DynamicExtensionInfoHelper {
} }
static const Message& GetMessage(const Extension& ext) { static const Message& GetMessage(const Extension& ext) {
return DownCast<const Message&>(*ext.message_value); return DownCastToMessage(*ext.message_value);
} }
static Message& MutableMessage(Extension& ext) { static Message& MutableMessage(Extension& ext) {
return DownCast<Message&>(*ext.message_value); return DownCastToMessage(*ext.message_value);
} }
static void ClearMessage(Extension& ext) { static void ClearMessage(Extension& ext) {
ext.is_cleared = true; ext.is_cleared = true;
@ -212,18 +212,18 @@ struct DynamicExtensionInfoHelper {
static const Message& GetLazyMessage(const Extension& ext, static const Message& GetLazyMessage(const Extension& ext,
const Message& prototype, Arena* arena) { const Message& prototype, Arena* arena) {
return DownCast<const Message&>( return DownCastToMessage(
ext.lazymessage_value->GetMessage(prototype, arena)); ext.lazymessage_value->GetMessage(prototype, arena));
} }
static const Message& GetLazyMessageIgnoreUnparsed(const Extension& ext, static const Message& GetLazyMessageIgnoreUnparsed(const Extension& ext,
const Message& prototype, const Message& prototype,
Arena* arena) { Arena* arena) {
return DownCast<const Message&>( return DownCastToMessage(
ext.lazymessage_value->GetMessageIgnoreUnparsed(prototype, arena)); ext.lazymessage_value->GetMessageIgnoreUnparsed(prototype, arena));
} }
static Message& MutableLazyMessage(Extension& ext, const Message& prototype, static Message& MutableLazyMessage(Extension& ext, const Message& prototype,
Arena* arena) { Arena* arena) {
return DownCast<Message&>( return DownCastToMessage(
*ext.lazymessage_value->MutableMessage(prototype, arena)); *ext.lazymessage_value->MutableMessage(prototype, arena));
} }
static void ClearLazyMessage(Extension& ext) { static void ClearLazyMessage(Extension& ext) {

@ -422,7 +422,7 @@ void ReflectionVisit::VisitMessageFields(const Message& message,
FieldDescriptor::CPPTYPE_MESSAGE) { FieldDescriptor::CPPTYPE_MESSAGE) {
if constexpr (info.is_repeated) { if constexpr (info.is_repeated) {
for (const auto& it : info.Get()) { for (const auto& it : info.Get()) {
func(DownCast<const Message&>(it)); func(DownCastToMessage(it));
} }
} else { } else {
func(info.Get()); func(info.Get());
@ -452,7 +452,7 @@ void ReflectionVisit::VisitMessageFields(Message& message, CallbackFn&& func) {
FieldDescriptor::CPPTYPE_MESSAGE) { FieldDescriptor::CPPTYPE_MESSAGE) {
if constexpr (info.is_repeated) { if constexpr (info.is_repeated) {
for (auto& it : info.Mutable()) { for (auto& it : info.Mutable()) {
func(DownCast<Message&>(it)); func(DownCastToMessage(it));
} }
} else { } else {
func(info.Mutable()); func(info.Mutable());

@ -91,7 +91,7 @@ TEST(REFLECTION_TEST, RegularFields) {
EXPECT_EQ(rf_double.Get(i), Func(i, 2)); EXPECT_EQ(rf_double.Get(i), Func(i, 2));
EXPECT_EQ(rpf_string.Get(i), StrFunc(i, 5)); EXPECT_EQ(rpf_string.Get(i), StrFunc(i, 5));
EXPECT_EQ(rpf_foreign_message.Get(i).c(), Func(i, 6)); EXPECT_EQ(rpf_foreign_message.Get(i).c(), Func(i, 6));
EXPECT_EQ(DownCast<const ForeignMessage*>(&rpf_message.Get(i))->c(), EXPECT_EQ(DownCastToGenerated<ForeignMessage>(&rpf_message.Get(i))->c(),
Func(i, 6)); Func(i, 6));
// Check gets through mutable objects. // Check gets through mutable objects.
@ -99,7 +99,7 @@ TEST(REFLECTION_TEST, RegularFields) {
EXPECT_EQ(mrf_double->Get(i), Func(i, 2)); EXPECT_EQ(mrf_double->Get(i), Func(i, 2));
EXPECT_EQ(mrpf_string->Get(i), StrFunc(i, 5)); EXPECT_EQ(mrpf_string->Get(i), StrFunc(i, 5));
EXPECT_EQ(mrpf_foreign_message->Get(i).c(), Func(i, 6)); EXPECT_EQ(mrpf_foreign_message->Get(i).c(), Func(i, 6));
EXPECT_EQ(DownCast<const ForeignMessage*>(&mrpf_message->Get(i))->c(), EXPECT_EQ(DownCastToGenerated<ForeignMessage>(&mrpf_message->Get(i))->c(),
Func(i, 6)); Func(i, 6));
// Check sets through mutable objects. // Check sets through mutable objects.
@ -111,7 +111,8 @@ TEST(REFLECTION_TEST, RegularFields) {
EXPECT_EQ(message.repeated_double(i), Func(i, -2)); EXPECT_EQ(message.repeated_double(i), Func(i, -2));
EXPECT_EQ(message.repeated_string(i), StrFunc(i, -5)); EXPECT_EQ(message.repeated_string(i), StrFunc(i, -5));
EXPECT_EQ(message.repeated_foreign_message(i).c(), Func(i, -6)); EXPECT_EQ(message.repeated_foreign_message(i).c(), Func(i, -6));
DownCast<ForeignMessage*>(mrpf_message->Mutable(i))->set_c(Func(i, 7)); DownCastToGenerated<ForeignMessage>(mrpf_message->Mutable(i))
->set_c(Func(i, 7));
EXPECT_EQ(message.repeated_foreign_message(i).c(), Func(i, 7)); EXPECT_EQ(message.repeated_foreign_message(i).c(), Func(i, 7));
} }
@ -271,7 +272,8 @@ TEST(REFLECTION_TEST, RepeatedFieldRefForRegularFields) {
ForeignMessage scratch_space; ForeignMessage scratch_space;
EXPECT_EQ(rf_foreign_message.Get(i, &scratch_space).c(), Func(i, 6)); EXPECT_EQ(rf_foreign_message.Get(i, &scratch_space).c(), Func(i, 6));
EXPECT_EQ( EXPECT_EQ(
DownCast<const ForeignMessage&>(rf_message.Get(i, &scratch_space)).c(), DownCastToGenerated<ForeignMessage>(rf_message.Get(i, &scratch_space))
.c(),
Func(i, 6)); Func(i, 6));
// Check gets through mutable objects. // Check gets through mutable objects.
@ -280,7 +282,8 @@ TEST(REFLECTION_TEST, RepeatedFieldRefForRegularFields) {
EXPECT_EQ(mrf_string.Get(i), StrFunc(i, 5)); EXPECT_EQ(mrf_string.Get(i), StrFunc(i, 5));
EXPECT_EQ(mrf_foreign_message.Get(i, &scratch_space).c(), Func(i, 6)); EXPECT_EQ(mrf_foreign_message.Get(i, &scratch_space).c(), Func(i, 6));
EXPECT_EQ( EXPECT_EQ(
DownCast<const ForeignMessage&>(mrf_message.Get(i, &scratch_space)).c(), DownCastToGenerated<ForeignMessage>(mrf_message.Get(i, &scratch_space))
.c(),
Func(i, 6)); Func(i, 6));
// Check sets through mutable objects. // Check sets through mutable objects.

Loading…
Cancel
Save