Use generic DeleteNode to reduce code size of `erase` in `Map` and to simplify the parsing logic in `MpMap`.

PiperOrigin-RevId: 704832360
pull/19535/head
Protobuf Team Bot 3 months ago committed by Copybara-Service
parent 4b397e522c
commit 828716eb57
  1. 40
      rust/cpp_kernel/map.cc
  2. 2
      src/google/protobuf/dynamic_message.cc
  3. 4
      src/google/protobuf/generated_message_tctable_impl.h
  4. 167
      src/google/protobuf/generated_message_tctable_lite.cc
  5. 63
      src/google/protobuf/map.h
  6. 8
      src/google/protobuf/map_test.inc

@ -7,6 +7,7 @@
#include <utility>
#include "absl/functional/overload.h"
#include "absl/log/absl_log.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"
@ -48,6 +49,13 @@ template <typename Key>
using KeyMap = internal::KeyMapBase<
internal::KeyForBase<typename FromViewType<Key>::type>>;
template <typename T>
T AsViewType(T t) {
return t;
}
absl::string_view AsViewType(PtrAndLen key) { return key.AsStringView(); }
void InitializeMessageValue(void* raw_ptr, MessageLite* msg) {
MessageLite* new_msg = internal::RustMapHelper::PlacementNew(msg, raw_ptr);
auto* full_msg = DynamicCastMessage<Message>(new_msg);
@ -86,28 +94,8 @@ bool Insert(internal::UntypedMapBase* m, Key key, MapValue value) {
},
});
node = internal::RustMapHelper::InsertOrReplaceNode(
return internal::RustMapHelper::InsertOrReplaceNode(
static_cast<KeyMap<Key>*>(m), node);
if (node == nullptr) {
return true;
}
internal::RustMapHelper::DeleteNode(m, node);
return false;
}
template <typename Map, typename Key,
typename = typename std::enable_if<
!std::is_same<Key, google::protobuf::rust::PtrAndLen>::value>::type>
internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, Key key) {
return internal::RustMapHelper::FindHelper(
m, static_cast<internal::KeyForBase<Key>>(key));
}
template <typename Map>
internal::RustMapHelper::NodeAndBucket FindHelper(Map* m,
google::protobuf::rust::PtrAndLen key) {
return internal::RustMapHelper::FindHelper(
m, absl::string_view(key.ptr, key.len));
}
void PopulateMapValue(const internal::UntypedMapBase& map,
@ -147,7 +135,7 @@ void PopulateMapValue(const internal::UntypedMapBase& map,
template <typename Key>
bool Get(internal::UntypedMapBase* m, Key key, MapValue* value) {
auto* map_base = static_cast<KeyMap<Key>*>(m);
internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key);
auto result = internal::RustMapHelper::FindHelper(map_base, AsViewType(key));
if (result.node == nullptr) {
return false;
}
@ -158,13 +146,7 @@ bool Get(internal::UntypedMapBase* m, Key key, MapValue* value) {
template <typename Key>
bool Remove(internal::UntypedMapBase* m, Key key) {
auto* map_base = static_cast<KeyMap<Key>*>(m);
internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key);
if (result.node == nullptr) {
return false;
}
internal::RustMapHelper::EraseNoDestroy(map_base, result.bucket, result.node);
internal::RustMapHelper::DeleteNode(m, result.node);
return true;
return internal::RustMapHelper::EraseImpl(map_base, AsViewType(key));
}
template <typename Key>

@ -269,7 +269,7 @@ bool DynamicMapField::DeleteMapValueImpl(MapFieldBase& base,
if (self.arena() == nullptr) {
it->second.DeleteData();
}
self.map_.erase(it);
self.map_.EraseDynamic(it);
return true;
}

@ -1007,8 +1007,8 @@ class PROTOBUF_EXPORT TcParser final {
static void WriteMapEntryAsUnknown(MessageLite* msg,
const TcParseTableBase* table,
uint32_t tag, NodeBase* node,
MapAuxInfo map_info);
UntypedMapBase& map, uint32_t tag,
NodeBase* node, MapAuxInfo map_info);
static const char* ParseOneMapEntry(NodeBase* node, const char* ptr,
ParseContext* ctx,

@ -2600,63 +2600,49 @@ error:
PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
static void SerializeMapKey(const NodeBase* node, MapTypeCard type_card,
static void SerializeMapKey(UntypedMapBase& map, NodeBase* node,
MapTypeCard type_card,
io::CodedOutputStream& coded_output) {
switch (type_card.wiretype()) {
case WireFormatLite::WIRETYPE_VARINT:
switch (type_card.cpp_type()) {
case MapTypeCard::kBool:
WireFormatLite::WriteBool(
1, static_cast<const KeyNode<bool>*>(node)->key(), &coded_output);
break;
case MapTypeCard::k32:
if (type_card.is_zigzag()) {
WireFormatLite::WriteSInt32(
1, static_cast<const KeyNode<uint32_t>*>(node)->key(),
&coded_output);
} else if (type_card.is_signed()) {
WireFormatLite::WriteInt32(
1, static_cast<const KeyNode<uint32_t>*>(node)->key(),
&coded_output);
} else {
WireFormatLite::WriteUInt32(
1, static_cast<const KeyNode<uint32_t>*>(node)->key(),
&coded_output);
}
break;
case MapTypeCard::k64:
if (type_card.is_zigzag()) {
WireFormatLite::WriteSInt64(
1, static_cast<const KeyNode<uint64_t>*>(node)->key(),
&coded_output);
} else if (type_card.is_signed()) {
WireFormatLite::WriteInt64(
1, static_cast<const KeyNode<uint64_t>*>(node)->key(),
&coded_output);
} else {
WireFormatLite::WriteUInt64(
1, static_cast<const KeyNode<uint64_t>*>(node)->key(),
&coded_output);
}
break;
default:
Unreachable();
}
map.VisitKey(node, //
absl::Overload{
[&](const bool* v) {
WireFormatLite::WriteBool(1, *v, &coded_output);
},
[&](const uint32_t* v) {
if (type_card.is_zigzag()) {
WireFormatLite::WriteSInt32(1, *v, &coded_output);
} else if (type_card.is_signed()) {
WireFormatLite::WriteInt32(1, *v, &coded_output);
} else {
WireFormatLite::WriteUInt32(1, *v, &coded_output);
}
},
[&](const uint64_t* v) {
if (type_card.is_zigzag()) {
WireFormatLite::WriteSInt64(1, *v, &coded_output);
} else if (type_card.is_signed()) {
WireFormatLite::WriteInt64(1, *v, &coded_output);
} else {
WireFormatLite::WriteUInt64(1, *v, &coded_output);
}
},
[](const void*) { Unreachable(); },
});
break;
case WireFormatLite::WIRETYPE_FIXED32:
WireFormatLite::WriteFixed32(
1, static_cast<const KeyNode<uint32_t>*>(node)->key(), &coded_output);
WireFormatLite::WriteFixed32(1, *map.GetKey<uint32_t>(node),
&coded_output);
break;
case WireFormatLite::WIRETYPE_FIXED64:
WireFormatLite::WriteFixed64(
1, static_cast<const KeyNode<uint64_t>*>(node)->key(), &coded_output);
WireFormatLite::WriteFixed64(1, *map.GetKey<uint64_t>(node),
&coded_output);
break;
case WireFormatLite::WIRETYPE_LENGTH_DELIMITED:
// We should never have a message here. They can only be values maps.
ABSL_DCHECK_EQ(+type_card.cpp_type(), +MapTypeCard::kString);
WireFormatLite::WriteString(
1, static_cast<const KeyNode<std::string>*>(node)->key(),
&coded_output);
WireFormatLite::WriteString(1, *map.GetKey<std::string>(node),
&coded_output);
break;
default:
Unreachable();
@ -2665,21 +2651,22 @@ static void SerializeMapKey(const NodeBase* node, MapTypeCard type_card,
void TcParser::WriteMapEntryAsUnknown(MessageLite* msg,
const TcParseTableBase* table,
uint32_t tag, NodeBase* node,
MapAuxInfo map_info) {
UntypedMapBase& map, uint32_t tag,
NodeBase* node, MapAuxInfo map_info) {
std::string serialized;
{
io::StringOutputStream string_output(&serialized);
io::CodedOutputStream coded_output(&string_output);
SerializeMapKey(node, map_info.key_type_card, coded_output);
SerializeMapKey(map, node, map_info.key_type_card, coded_output);
// The mapped_type is always an enum here.
ABSL_DCHECK(map_info.value_is_validated_enum);
WireFormatLite::WriteInt32(2,
*reinterpret_cast<int32_t*>(
node->GetVoidValue(map_info.node_size_info)),
&coded_output);
WireFormatLite::WriteInt32(2, *map.GetValue<int32_t>(node), &coded_output);
}
GetUnknownFieldOps(table).write_length_delimited(msg, tag >> 3, serialized);
if (map.arena() == nullptr) {
map.DeleteNode(node);
}
}
template <typename T>
@ -2865,51 +2852,41 @@ PROTOBUF_NOINLINE const char* TcParser::MpMap(PROTOBUF_TC_PARAM_DECL) {
return ParseOneMapEntry(node, ptr, ctx, aux, table, entry, map.arena());
});
if (ABSL_PREDICT_TRUE(ptr != nullptr)) {
if (ABSL_PREDICT_FALSE(map_info.value_is_validated_enum &&
!internal::ValidateEnumInlined(
*static_cast<int32_t*>(node->GetVoidValue(
map_info.node_size_info)),
aux[1].enum_data))) {
WriteMapEntryAsUnknown(msg, table, saved_tag, node, map_info);
} else {
// Done parsing the node, try to insert it.
// If it overwrites something we get old node back to destroy it.
switch (map_info.key_type_card.cpp_type()) {
case MapTypeCard::kBool:
node = static_cast<KeyMapBase<bool>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<bool>::KeyNode*>(node));
break;
case MapTypeCard::k32:
node = static_cast<KeyMapBase<uint32_t>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<uint32_t>::KeyNode*>(node));
break;
case MapTypeCard::k64:
node = static_cast<KeyMapBase<uint64_t>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<uint64_t>::KeyNode*>(node));
break;
case MapTypeCard::kString:
node =
static_cast<KeyMapBase<std::string>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<std::string>::KeyNode*>(node));
break;
default:
Unreachable();
}
}
}
// Destroy the node if we have it.
// It could be because we failed to parse, or because insertion returned
// an overwritten node.
if (ABSL_PREDICT_FALSE(node != nullptr && map.arena() == nullptr)) {
map.DeleteNode(node);
}
if (ABSL_PREDICT_FALSE(ptr == nullptr)) {
// Parsing failed. Delete the node that we didn't insert.
if (map.arena() == nullptr) map.DeleteNode(node);
PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
if (ABSL_PREDICT_FALSE(
map_info.value_is_validated_enum &&
!internal::ValidateEnumInlined(*map.GetValue<int32_t>(node),
aux[1].enum_data))) {
WriteMapEntryAsUnknown(msg, table, map, saved_tag, node, map_info);
} else {
// Done parsing the node, insert it.
switch (map_info.key_type_card.cpp_type()) {
case MapTypeCard::kBool:
static_cast<KeyMapBase<bool>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<bool>::KeyNode*>(node));
break;
case MapTypeCard::k32:
static_cast<KeyMapBase<uint32_t>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<uint32_t>::KeyNode*>(node));
break;
case MapTypeCard::k64:
static_cast<KeyMapBase<uint64_t>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<uint64_t>::KeyNode*>(node));
break;
case MapTypeCard::kString:
static_cast<KeyMapBase<std::string>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<std::string>::KeyNode*>(node));
break;
default:
Unreachable();
}
}
if (ABSL_PREDICT_FALSE(!ctx->DataAvailable(ptr))) {
PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}

@ -680,7 +680,8 @@ class KeyMapBase : public UntypedMapBase {
friend class RustMapHelper;
friend class v2::TableDriven;
PROTOBUF_NOINLINE void erase_no_destroy(map_index_t b, KeyNode* node) {
PROTOBUF_NOINLINE size_type EraseImpl(map_index_t b, KeyNode* node,
bool do_destroy) {
// Force bucket_index to be in range.
b &= (num_buckets_ - 1);
@ -708,6 +709,20 @@ class KeyMapBase : public UntypedMapBase {
++index_of_first_non_null_;
}
}
if (arena() == nullptr && do_destroy) {
DeleteNode(node);
}
// To allow for the other overload of EraseImpl to do a tail call.
return 1;
}
PROTOBUF_NOINLINE size_type EraseImpl(typename TS::ViewType k) {
if (auto result = FindHelper(k); result.node != nullptr) {
return EraseImpl(result.bucket, static_cast<KeyNode*>(result.node), true);
}
return 0;
}
NodeAndBucket FindHelper(typename TS::ViewType k) const {
@ -721,22 +736,20 @@ class KeyMapBase : public UntypedMapBase {
}
// Insert the given node.
// If the key is a duplicate, it inserts the new node and returns the old one.
// Gives ownership to the caller.
// If the key is unique, it returns `nullptr`.
KeyNode* InsertOrReplaceNode(KeyNode* node) {
KeyNode* to_erase = nullptr;
// If the key is a duplicate, it inserts the new node and deletes the old one.
bool InsertOrReplaceNode(KeyNode* node) {
bool is_new = true;
auto p = this->FindHelper(node->key());
map_index_t b = p.bucket;
if (p.node != nullptr) {
erase_no_destroy(p.bucket, static_cast<KeyNode*>(p.node));
to_erase = static_cast<KeyNode*>(p.node);
if (ABSL_PREDICT_FALSE(p.node != nullptr)) {
EraseImpl(p.bucket, static_cast<KeyNode*>(p.node), true);
is_new = false;
} else if (ResizeIfLoadIsOutOfRange(num_elements_ + 1)) {
b = BucketNumber(node->key()); // bucket_number
}
InsertUnique(b, node);
++num_elements_;
return to_erase;
return is_new;
}
// Insert the given Node in bucket b. If that would make bucket b too big,
@ -876,13 +889,13 @@ class RustMapHelper {
}
template <typename Map>
static typename Map::KeyNode* InsertOrReplaceNode(Map* m, NodeBase* node) {
static bool InsertOrReplaceNode(Map* m, NodeBase* node) {
return m->InsertOrReplaceNode(static_cast<typename Map::KeyNode*>(node));
}
template <typename Map>
static void EraseNoDestroy(Map* m, map_index_t bucket, NodeBase* node) {
m->erase_no_destroy(bucket, static_cast<typename Map::KeyNode*>(node));
template <typename Map, typename Key>
static bool EraseImpl(Map* m, const Key& key) {
return m->EraseImpl(key);
}
static google::protobuf::MessageLite* PlacementNew(const MessageLite* prototype,
@ -1296,21 +1309,13 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
// Erase and clear
template <typename K = key_type>
size_type erase(const key_arg<K>& key) {
iterator it = find(key);
if (it == end()) {
return 0;
} else {
erase(it);
return 1;
}
return this->EraseImpl(TS::ToView(key));
}
iterator erase(iterator pos) ABSL_ATTRIBUTE_LIFETIME_BOUND {
auto next = std::next(pos);
ABSL_DCHECK_EQ(pos.m_, static_cast<Base*>(this));
auto* node = static_cast<Node*>(pos.node_);
this->erase_no_destroy(pos.bucket_index_, node);
DeleteNode(node);
this->EraseImpl(pos.bucket_index_, static_cast<Node*>(pos.node_), true);
return next;
}
@ -1429,6 +1434,15 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
true);
}
// For DynamicMapField, which needs a special destructor.
void EraseDynamic(iterator it) {
this->EraseImpl(it.bucket_index_,
static_cast<typename Map::KeyNode*>(it.node_), false);
if (this->arena() == nullptr) {
delete static_cast<Node*>(it.node_);
}
}
using Base::arena;
friend class Arena;
@ -1438,6 +1452,7 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
using DestructorSkippable_ = void;
template <typename K, typename V>
friend class internal::MapFieldLite;
friend class internal::DynamicMapField;
friend class internal::TcParser;
friend struct internal::MapTestPeer;
friend struct internal::MapBenchmarkPeer;

@ -183,13 +183,7 @@ struct MapTestPeer {
using Node = typename T::Node;
auto* node = static_cast<Node*>(map.AllocNode(sizeof(Node)));
::new (static_cast<void*>(&node->kv)) typename T::value_type{key, value};
node = static_cast<Node*>(GetKeyMapBase(map).InsertOrReplaceNode(node));
if (node) {
node->~Node();
GetKeyMapBase(map).DeallocNode(node, sizeof(Node));
return false;
}
return true;
return GetKeyMapBase(map).InsertOrReplaceNode(node);
}
template <typename T>

Loading…
Cancel
Save