diff --git a/upbc/generator.cc b/upbc/generator.cc index dbbccf0a48..08f488fbea 100644 --- a/upbc/generator.cc +++ b/upbc/generator.cc @@ -704,25 +704,38 @@ int TableDescriptorType(const protobuf::FieldDescriptor* field) { } struct SubmsgArray { - std::vector messages; - absl::flat_hash_map indexes; -}; - -SubmsgArray GetSubmsgArray(const protobuf::Descriptor* message) { - SubmsgArray ret; - MessageLayout layout(message); - std::vector sorted_submsgs = - SortedSubmessages(message); - int i = 0; - for (auto submsg : sorted_submsgs) { - if (ret.indexes.find(submsg->message_type()) != ret.indexes.end()) { - continue; + public: + SubmsgArray(const protobuf::Descriptor* message) : message_(message) { + MessageLayout layout(message); + std::vector sorted_submsgs = + SortedSubmessages(message); + int i = 0; + for (auto submsg : sorted_submsgs) { + if (indexes_.find(submsg->message_type()) != indexes_.end()) { + continue; + } + submsgs_.push_back(submsg->message_type()); + indexes_[submsg->message_type()] = i++; } - ret.messages.push_back(submsg->message_type()); - ret.indexes[submsg->message_type()] = i++; } - return ret; -} + + const std::vector& submsgs() const { + return submsgs_; + } + + int GetIndex(const protobuf::FieldDescriptor* field) { + (void)message_; + assert(field->containing_type() == message_); + auto it = indexes_.find(field->message_type()); + assert(it != indexes_.end()); + return it->second; + } + + private: + const protobuf::Descriptor* message_; + std::vector submsgs_; + absl::flat_hash_map indexes_; +}; void WriteSource(const protobuf::FileDescriptor* file, Output& output) { EmitFileWarning(file, output); @@ -748,17 +761,17 @@ void WriteSource(const protobuf::FileDescriptor* file, Output& output) { std::string fields_array_ref = "NULL"; std::string submsgs_array_ref = "NULL"; MessageLayout layout(message); - SubmsgArray submsg_array = GetSubmsgArray(message); + SubmsgArray submsg_array(message); - if (!submsg_array.messages.empty()) { + if (!submsg_array.submsgs().empty()) { // TODO(haberman): could save a little bit of space by only generating a // "submsgs" array for every strongly-connected component. std::string submsgs_array_name = msgname + "_submsgs"; submsgs_array_ref = "&" + submsgs_array_name + "[0]"; output("static const upb_msglayout *const $0[$1] = {\n", - submsgs_array_name, submsg_array.messages.size()); + submsgs_array_name, submsg_array.submsgs().size()); - for (auto submsg : submsg_array.messages) { + for (auto submsg : submsg_array.submsgs()) { output(" &$0,\n", MessageInit(submsg)); } @@ -777,7 +790,7 @@ void WriteSource(const protobuf::FileDescriptor* file, Output& output) { std::string presence = "0"; if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - submsg_index = submsg_array.indexes[field->message_type()]; + submsg_index = submsg_array.GetIndex(field); } if (MessageLayout::HasHasbit(field)) {