From 85c7cc24e0084a3beda37b848e7037f8f09c6850 Mon Sep 17 00:00:00 2001 From: Joshua Haberman Date: Tue, 4 Jun 2024 20:04:41 -0700 Subject: [PATCH] Add an indirection to sub-messages pointers to allow for static tree shaking. PiperOrigin-RevId: 640369522 --- upb/mini_descriptor/decode.c | 9 ++- upb/mini_descriptor/link.c | 8 +- upb/mini_table/internal/message.h | 12 +-- upb/mini_table/internal/sub.h | 5 ++ upb/wire/decode.c | 89 +++++++++++++---------- upb/wire/encode.c | 46 +++++++----- upb_generator/protoc-gen-upb_minitable.cc | 25 +++++-- 7 files changed, 119 insertions(+), 75 deletions(-) diff --git a/upb/mini_descriptor/decode.c b/upb/mini_descriptor/decode.c index d355112f57..1094fc0435 100644 --- a/upb/mini_descriptor/decode.c +++ b/upb/mini_descriptor/decode.c @@ -27,6 +27,7 @@ #include "upb/mini_table/field.h" #include "upb/mini_table/internal/field.h" #include "upb/mini_table/internal/message.h" +#include "upb/mini_table/internal/sub.h" #include "upb/mini_table/message.h" #include "upb/mini_table/sub.h" @@ -407,11 +408,15 @@ static void upb_MtDecoder_AllocateSubs(upb_MtDecoder* d, upb_SubCounts sub_counts) { uint32_t total_count = sub_counts.submsg_count + sub_counts.subenum_count; size_t subs_bytes = sizeof(*d->table->UPB_PRIVATE(subs)) * total_count; - upb_MiniTableSub* subs = upb_Arena_Malloc(d->arena, subs_bytes); + size_t ptrs_bytes = sizeof(upb_MiniTable*) * sub_counts.submsg_count; + upb_MiniTableSubInternal* subs = upb_Arena_Malloc(d->arena, subs_bytes); + const upb_MiniTable** subs_ptrs = upb_Arena_Malloc(d->arena, ptrs_bytes); upb_MdDecoder_CheckOutOfMemory(&d->base, subs); + upb_MdDecoder_CheckOutOfMemory(&d->base, subs_ptrs); uint32_t i = 0; for (; i < sub_counts.submsg_count; i++) { - subs[i].UPB_PRIVATE(submsg) = UPB_PRIVATE(_upb_MiniTable_Empty)(); + subs_ptrs[i] = UPB_PRIVATE(_upb_MiniTable_Empty)(); + subs[i].UPB_PRIVATE(submsg) = &subs_ptrs[i]; } if (sub_counts.subenum_count) { upb_MiniTableField* f = d->fields; diff --git a/upb/mini_descriptor/link.c b/upb/mini_descriptor/link.c index 093150b623..5dec59e324 100644 --- a/upb/mini_descriptor/link.c +++ b/upb/mini_descriptor/link.c @@ -9,10 +9,14 @@ #include #include +#include #include "upb/base/descriptor_constants.h" #include "upb/mini_table/enum.h" #include "upb/mini_table/field.h" +#include "upb/mini_table/internal/field.h" +#include "upb/mini_table/internal/message.h" +#include "upb/mini_table/internal/sub.h" #include "upb/mini_table/message.h" #include "upb/mini_table/sub.h" @@ -51,11 +55,11 @@ bool upb_MiniTable_SetSubMessage(upb_MiniTable* table, } int idx = field->UPB_PRIVATE(submsg_index); - upb_MiniTableSub* table_subs = (void*)table->UPB_PRIVATE(subs); + upb_MiniTableSubInternal* table_subs = (void*)table->UPB_PRIVATE(subs); // TODO: Add this assert back once YouTube is updated to not call // this function repeatedly. // UPB_ASSERT(UPB_PRIVATE(_upb_MiniTable_IsEmpty)(table_sub->submsg)); - table_subs[idx] = upb_MiniTableSub_FromMessage(sub); + memcpy((void*)table_subs[idx].UPB_PRIVATE(submsg), &sub, sizeof(void*)); return true; } diff --git a/upb/mini_table/internal/message.h b/upb/mini_table/internal/message.h index d5b1ae4e0b..2c618ce233 100644 --- a/upb/mini_table/internal/message.h +++ b/upb/mini_table/internal/message.h @@ -46,7 +46,7 @@ typedef enum { // LINT.IfChange(minitable_struct_definition) struct upb_MiniTable { - const union upb_MiniTableSub* UPB_PRIVATE(subs); + const upb_MiniTableSubInternal* UPB_PRIVATE(subs); const struct upb_MiniTableField* UPB_ONLYBITS(fields); // Must be aligned to sizeof(void*). Doesn't include internal members like @@ -99,9 +99,10 @@ UPB_API_INLINE const struct upb_MiniTableField* upb_MiniTable_GetFieldByIndex( return &m->UPB_ONLYBITS(fields)[i]; } -UPB_INLINE const union upb_MiniTableSub UPB_PRIVATE( - _upb_MiniTable_GetSubByIndex)(const struct upb_MiniTable* m, uint32_t i) { - return m->UPB_PRIVATE(subs)[i]; +UPB_INLINE const struct upb_MiniTable* UPB_PRIVATE( + _upb_MiniTable_GetSubTableByIndex)(const struct upb_MiniTable* m, + uint32_t i) { + return *m->UPB_PRIVATE(subs)[i].UPB_PRIVATE(submsg); } UPB_API_INLINE const struct upb_MiniTable* upb_MiniTable_SubMessage( @@ -109,7 +110,8 @@ UPB_API_INLINE const struct upb_MiniTable* upb_MiniTable_SubMessage( if (upb_MiniTableField_CType(f) != kUpb_CType_Message) { return NULL; } - return m->UPB_PRIVATE(subs)[f->UPB_PRIVATE(submsg_index)].UPB_PRIVATE(submsg); + return UPB_PRIVATE(_upb_MiniTable_GetSubTableByIndex)( + m, f->UPB_PRIVATE(submsg_index)); } UPB_API_INLINE const struct upb_MiniTable* upb_MiniTable_GetSubMessageTable( diff --git a/upb/mini_table/internal/sub.h b/upb/mini_table/internal/sub.h index 967b557dda..4c21569ba0 100644 --- a/upb/mini_table/internal/sub.h +++ b/upb/mini_table/internal/sub.h @@ -11,6 +11,11 @@ // Must be last. #include "upb/port/def.inc" +typedef union { + const struct upb_MiniTable* const* UPB_PRIVATE(submsg); + const struct upb_MiniTableEnum* UPB_PRIVATE(subenum); +} upb_MiniTableSubInternal; + union upb_MiniTableSub { const struct upb_MiniTable* UPB_PRIVATE(submsg); const struct upb_MiniTableEnum* UPB_PRIVATE(subenum); diff --git a/upb/wire/decode.c b/upb/wire/decode.c index 59602aee30..2cb1a44913 100644 --- a/upb/wire/decode.c +++ b/upb/wire/decode.c @@ -35,7 +35,7 @@ #include "upb/mini_table/field.h" #include "upb/mini_table/internal/field.h" #include "upb/mini_table/internal/message.h" -#include "upb/mini_table/internal/size_log2.h" +#include "upb/mini_table/internal/sub.h" #include "upb/mini_table/message.h" #include "upb/mini_table/sub.h" #include "upb/port/atomic.h" @@ -97,15 +97,15 @@ typedef union { // Returns the MiniTable corresponding to a given MiniTableField // from an array of MiniTableSubs. static const upb_MiniTable* _upb_MiniTableSubs_MessageByField( - const upb_MiniTableSub* subs, const upb_MiniTableField* field) { - return upb_MiniTableSub_Message(subs[field->UPB_PRIVATE(submsg_index)]); + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field) { + return *subs[field->UPB_PRIVATE(submsg_index)].UPB_PRIVATE(submsg); } // Returns the MiniTableEnum corresponding to a given MiniTableField // from an array of MiniTableSub. static const upb_MiniTableEnum* _upb_MiniTableSubs_EnumByField( - const upb_MiniTableSub* subs, const upb_MiniTableField* field) { - return upb_MiniTableSub_Enum(subs[field->UPB_PRIVATE(submsg_index)]); + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field) { + return subs[field->UPB_PRIVATE(submsg_index)].UPB_PRIVATE(subenum); } static const char* _upb_Decoder_DecodeMessage(upb_Decoder* d, const char* ptr, @@ -240,11 +240,10 @@ static void _upb_Decoder_Munge(int type, wireval* val) { } } -static upb_Message* _upb_Decoder_NewSubMessage(upb_Decoder* d, - const upb_MiniTableSub* subs, - const upb_MiniTableField* field, - upb_TaggedMessagePtr* target) { - const upb_MiniTable* subl = _upb_MiniTableSubs_MessageByField(subs, field); +static upb_Message* _upb_Decoder_NewSubMessage2(upb_Decoder* d, + const upb_MiniTable* subl, + const upb_MiniTableField* field, + upb_TaggedMessagePtr* target) { UPB_ASSERT(subl); upb_Message* msg = _upb_Message_New(subl, &d->arena); if (!msg) _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); @@ -265,8 +264,15 @@ static upb_Message* _upb_Decoder_NewSubMessage(upb_Decoder* d, return msg; } +static upb_Message* _upb_Decoder_NewSubMessage( + upb_Decoder* d, const upb_MiniTableSubInternal* subs, + const upb_MiniTableField* field, upb_TaggedMessagePtr* target) { + const upb_MiniTable* subl = _upb_MiniTableSubs_MessageByField(subs, field); + return _upb_Decoder_NewSubMessage2(d, subl, field, target); +} + static upb_Message* _upb_Decoder_ReuseSubMessage( - upb_Decoder* d, const upb_MiniTableSub* subs, + upb_Decoder* d, const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field, upb_TaggedMessagePtr* target) { upb_TaggedMessagePtr tagged = *target; const upb_MiniTable* subl = _upb_MiniTableSubs_MessageByField(subs, field); @@ -319,7 +325,7 @@ const char* _upb_Decoder_RecurseSubMessage(upb_Decoder* d, const char* ptr, UPB_FORCEINLINE const char* _upb_Decoder_DecodeSubMessage(upb_Decoder* d, const char* ptr, upb_Message* submsg, - const upb_MiniTableSub* subs, + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field, int size) { int saved_delta = upb_EpsCopyInputStream_PushLimit(&d->input, ptr, size); @@ -352,7 +358,7 @@ const char* _upb_Decoder_DecodeUnknownGroup(upb_Decoder* d, const char* ptr, UPB_FORCEINLINE const char* _upb_Decoder_DecodeKnownGroup(upb_Decoder* d, const char* ptr, upb_Message* submsg, - const upb_MiniTableSub* subs, + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field) { const upb_MiniTable* subl = _upb_MiniTableSubs_MessageByField(subs, field); UPB_ASSERT(subl); @@ -403,12 +409,10 @@ bool _upb_Decoder_CheckEnum(upb_Decoder* d, const char* ptr, upb_Message* msg, } UPB_NOINLINE -static const char* _upb_Decoder_DecodeEnumArray(upb_Decoder* d, const char* ptr, - upb_Message* msg, - upb_Array* arr, - const upb_MiniTableSub* subs, - const upb_MiniTableField* field, - wireval* val) { +static const char* _upb_Decoder_DecodeEnumArray( + upb_Decoder* d, const char* ptr, upb_Message* msg, upb_Array* arr, + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field, + wireval* val) { const upb_MiniTableEnum* e = _upb_MiniTableSubs_EnumByField(subs, field); if (!_upb_Decoder_CheckEnum(d, ptr, msg, e, field, val)) return ptr; void* mem = UPB_PTR_AT(upb_Array_MutableDataPtr(arr), @@ -484,7 +488,7 @@ const char* _upb_Decoder_DecodeVarintPacked(upb_Decoder* d, const char* ptr, UPB_NOINLINE static const char* _upb_Decoder_DecodeEnumPacked( upb_Decoder* d, const char* ptr, upb_Message* msg, upb_Array* arr, - const upb_MiniTableSub* subs, const upb_MiniTableField* field, + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field, wireval* val) { const upb_MiniTableEnum* e = _upb_MiniTableSubs_EnumByField(subs, field); int saved_limit = upb_EpsCopyInputStream_PushLimit(&d->input, ptr, val->size); @@ -518,11 +522,10 @@ static upb_Array* _upb_Decoder_CreateArray(upb_Decoder* d, return ret; } -static const char* _upb_Decoder_DecodeToArray(upb_Decoder* d, const char* ptr, - upb_Message* msg, - const upb_MiniTableSub* subs, - const upb_MiniTableField* field, - wireval* val, int op) { +static const char* _upb_Decoder_DecodeToArray( + upb_Decoder* d, const char* ptr, upb_Message* msg, + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field, + wireval* val, int op) { upb_Array** arrp = UPB_PTR_AT(msg, field->UPB_PRIVATE(offset), void); upb_Array* arr = *arrp; void* mem; @@ -623,11 +626,10 @@ static upb_Map* _upb_Decoder_CreateMap(upb_Decoder* d, return ret; } -static const char* _upb_Decoder_DecodeToMap(upb_Decoder* d, const char* ptr, - upb_Message* msg, - const upb_MiniTableSub* subs, - const upb_MiniTableField* field, - wireval* val) { +static const char* _upb_Decoder_DecodeToMap( + upb_Decoder* d, const char* ptr, upb_Message* msg, + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field, + wireval* val) { upb_Map** map_p = UPB_PTR_AT(msg, field->UPB_PRIVATE(offset), upb_Map*); upb_Map* map = *map_p; upb_MapEntry ent; @@ -688,8 +690,8 @@ static const char* _upb_Decoder_DecodeToMap(upb_Decoder* d, const char* ptr, static const char* _upb_Decoder_DecodeToSubMessage( upb_Decoder* d, const char* ptr, upb_Message* msg, - const upb_MiniTableSub* subs, const upb_MiniTableField* field, wireval* val, - int op) { + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field, + wireval* val, int op) { void* mem = UPB_PTR_AT(msg, field->UPB_PRIVATE(offset), void); int type = field->UPB_PRIVATE(descriptortype); @@ -819,9 +821,9 @@ static void upb_Decoder_AddKnownMessageSetItem( if (UPB_UNLIKELY(!ext)) { _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); } - upb_Message* submsg = _upb_Decoder_NewSubMessage( - d, &ext->ext->UPB_PRIVATE(sub), &ext->ext->UPB_PRIVATE(field), - (upb_TaggedMessagePtr*)&ext->data); + upb_Message* submsg = _upb_Decoder_NewSubMessage2( + d, ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(submsg), + &ext->ext->UPB_PRIVATE(field), (upb_TaggedMessagePtr*)&ext->data); upb_DecodeStatus status = upb_Decode( data, size, submsg, upb_MiniTableExtension_GetSubMessage(item_mt), d->extreg, d->options, &d->arena); @@ -1022,8 +1024,9 @@ void _upb_Decoder_CheckUnlinked(upb_Decoder* d, const upb_MiniTable* mt, // unlinked. do { UPB_ASSERT(upb_MiniTableField_CType(oneof) == kUpb_CType_Message); - const upb_MiniTableSub* oneof_sub = - &mt->UPB_PRIVATE(subs)[oneof->UPB_PRIVATE(submsg_index)]; + const upb_MiniTable* oneof_sub = + *mt->UPB_PRIVATE(subs)[oneof->UPB_PRIVATE(submsg_index)].UPB_PRIVATE( + submsg); UPB_ASSERT(!oneof_sub); } while (upb_MiniTable_NextOneofField(mt, &oneof)); } @@ -1161,8 +1164,9 @@ const char* _upb_Decoder_DecodeKnownField(upb_Decoder* d, const char* ptr, const upb_MiniTable* layout, const upb_MiniTableField* field, int op, wireval* val) { - const upb_MiniTableSub* subs = layout->UPB_PRIVATE(subs); + const upb_MiniTableSubInternal* subs = layout->UPB_PRIVATE(subs); uint8_t mode = field->UPB_PRIVATE(mode); + upb_MiniTableSubInternal ext_sub; if (UPB_UNLIKELY(mode & kUpb_LabelFlags_IsExtension)) { const upb_MiniTableExtension* ext_layout = @@ -1174,7 +1178,14 @@ const char* _upb_Decoder_DecodeKnownField(upb_Decoder* d, const char* ptr, } d->unknown_msg = msg; msg = (upb_Message*)&ext->data; - subs = &ext->ext->UPB_PRIVATE(sub); + if (upb_MiniTableField_IsSubMessage(&ext->ext->UPB_PRIVATE(field))) { + ext_sub.UPB_PRIVATE(submsg) = + &ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(submsg); + } else { + ext_sub.UPB_PRIVATE(subenum) = + ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(subenum); + } + subs = &ext_sub; } switch (mode & kUpb_FieldMode_Mask) { diff --git a/upb/wire/encode.c b/upb/wire/encode.c index 0b6d7c32cb..5764199e44 100644 --- a/upb/wire/encode.c +++ b/upb/wire/encode.c @@ -35,14 +35,21 @@ #include "upb/mini_table/field.h" #include "upb/mini_table/internal/field.h" #include "upb/mini_table/internal/message.h" +#include "upb/mini_table/internal/sub.h" #include "upb/mini_table/message.h" -#include "upb/mini_table/sub.h" #include "upb/wire/internal/constants.h" #include "upb/wire/types.h" // Must be last. #include "upb/port/def.inc" +// Returns the MiniTable corresponding to a given MiniTableField +// from an array of MiniTableSubs. +static const upb_MiniTable* _upb_Encoder_GetSubMiniTable( + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field) { + return *subs[field->UPB_PRIVATE(submsg_index)].UPB_PRIVATE(submsg); +} + #define UPB_PB_VARINT_MAX_LEN 10 UPB_NOINLINE @@ -224,7 +231,7 @@ static void encode_TaggedMessagePtr(upb_encstate* e, } static void encode_scalar(upb_encstate* e, const void* _field_mem, - const upb_MiniTableSub* subs, + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* f) { const char* field_mem = _field_mem; int wire_type; @@ -273,8 +280,7 @@ static void encode_scalar(upb_encstate* e, const void* _field_mem, case kUpb_FieldType_Group: { size_t size; upb_TaggedMessagePtr submsg = *(upb_TaggedMessagePtr*)field_mem; - const upb_MiniTable* subm = - upb_MiniTableSub_Message(subs[f->UPB_PRIVATE(submsg_index)]); + const upb_MiniTable* subm = _upb_Encoder_GetSubMiniTable(subs, f); if (submsg == 0) { return; } @@ -288,8 +294,7 @@ static void encode_scalar(upb_encstate* e, const void* _field_mem, case kUpb_FieldType_Message: { size_t size; upb_TaggedMessagePtr submsg = *(upb_TaggedMessagePtr*)field_mem; - const upb_MiniTable* subm = - upb_MiniTableSub_Message(subs[f->UPB_PRIVATE(submsg_index)]); + const upb_MiniTable* subm = _upb_Encoder_GetSubMiniTable(subs, f); if (submsg == 0) { return; } @@ -309,7 +314,7 @@ static void encode_scalar(upb_encstate* e, const void* _field_mem, } static void encode_array(upb_encstate* e, const upb_Message* msg, - const upb_MiniTableSub* subs, + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* f) { const upb_Array* arr = *UPB_PTR_AT(msg, f->UPB_PRIVATE(offset), upb_Array*); bool packed = upb_MiniTableField_IsPacked(f); @@ -379,8 +384,7 @@ static void encode_array(upb_encstate* e, const upb_Message* msg, case kUpb_FieldType_Group: { const upb_TaggedMessagePtr* start = upb_Array_DataPtr(arr); const upb_TaggedMessagePtr* ptr = start + upb_Array_Size(arr); - const upb_MiniTable* subm = - upb_MiniTableSub_Message(subs[f->UPB_PRIVATE(submsg_index)]); + const upb_MiniTable* subm = _upb_Encoder_GetSubMiniTable(subs, f); if (--e->depth == 0) encode_err(e, kUpb_EncodeStatus_MaxDepthExceeded); do { size_t size; @@ -395,8 +399,7 @@ static void encode_array(upb_encstate* e, const upb_Message* msg, case kUpb_FieldType_Message: { const upb_TaggedMessagePtr* start = upb_Array_DataPtr(arr); const upb_TaggedMessagePtr* ptr = start + upb_Array_Size(arr); - const upb_MiniTable* subm = - upb_MiniTableSub_Message(subs[f->UPB_PRIVATE(submsg_index)]); + const upb_MiniTable* subm = _upb_Encoder_GetSubMiniTable(subs, f); if (--e->depth == 0) encode_err(e, kUpb_EncodeStatus_MaxDepthExceeded); do { size_t size; @@ -432,11 +435,10 @@ static void encode_mapentry(upb_encstate* e, uint32_t number, } static void encode_map(upb_encstate* e, const upb_Message* msg, - const upb_MiniTableSub* subs, + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* f) { const upb_Map* map = *UPB_PTR_AT(msg, f->UPB_PRIVATE(offset), const upb_Map*); - const upb_MiniTable* layout = - upb_MiniTableSub_Message(subs[f->UPB_PRIVATE(submsg_index)]); + const upb_MiniTable* layout = _upb_Encoder_GetSubMiniTable(subs, f); UPB_ASSERT(upb_MiniTable_FieldCount(layout) == 2); if (!map || !upb_Map_Size(map)) return; @@ -465,7 +467,6 @@ static void encode_map(upb_encstate* e, const upb_Message* msg, } static bool encode_shouldencode(upb_encstate* e, const upb_Message* msg, - const upb_MiniTableSub* subs, const upb_MiniTableField* f) { if (f->presence == 0) { // Proto3 presence or map/array. @@ -504,7 +505,7 @@ static bool encode_shouldencode(upb_encstate* e, const upb_Message* msg, } static void encode_field(upb_encstate* e, const upb_Message* msg, - const upb_MiniTableSub* subs, + const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field) { switch (UPB_PRIVATE(_upb_MiniTableField_Mode)(field)) { case kUpb_FieldMode_Array: @@ -539,7 +540,14 @@ static void encode_ext(upb_encstate* e, const upb_Extension* ext, if (UPB_UNLIKELY(is_message_set)) { encode_msgset_item(e, ext); } else { - encode_field(e, (upb_Message*)&ext->data, &ext->ext->UPB_PRIVATE(sub), + upb_MiniTableSubInternal sub; + if (upb_MiniTableField_IsSubMessage(&ext->ext->UPB_PRIVATE(field))) { + sub.UPB_PRIVATE(submsg) = &ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(submsg); + } else { + sub.UPB_PRIVATE(subenum) = + ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(subenum); + } + encode_field(e, (upb_Message*)&ext->data, &sub, &ext->ext->UPB_PRIVATE(field)); } } @@ -595,7 +603,7 @@ static void encode_message(upb_encstate* e, const upb_Message* msg, const upb_MiniTableField* first = &m->UPB_PRIVATE(fields)[0]; while (f != first) { f--; - if (encode_shouldencode(e, msg, m->UPB_PRIVATE(subs), f)) { + if (encode_shouldencode(e, msg, f)) { encode_field(e, msg, m->UPB_PRIVATE(subs), f); } } @@ -682,4 +690,4 @@ const char* upb_EncodeStatus_String(upb_EncodeStatus status) { default: return "Unknown encode status"; } -} \ No newline at end of file +} diff --git a/upb_generator/protoc-gen-upb_minitable.cc b/upb_generator/protoc-gen-upb_minitable.cc index 7c15308663..9e80e19583 100644 --- a/upb_generator/protoc-gen-upb_minitable.cc +++ b/upb_generator/protoc-gen-upb_minitable.cc @@ -70,6 +70,10 @@ std::string ExtensionLayout(upb::FieldDefPtr ext) { return absl::StrCat(ExtensionIdentBase(ext), "_", ext.name(), "_ext"); } +std::string MessagePtrName(upb::MessageDefPtr message) { + return MessageInitName(message) + "_ptr"; +} + const char* kEnumsInit = "enums_layout"; const char* kExtensionsInit = "extensions_layout"; const char* kMessagesInit = "messages_layout"; @@ -312,10 +316,11 @@ void WriteMessageField(upb::FieldDefPtr field, output(" $0,\n", upb::generator::FieldInitializer(field, field64, field32)); } -std::string GetSub(upb::FieldDefPtr field) { +std::string GetSub(upb::FieldDefPtr field, bool is_extension) { if (auto message_def = field.message_type()) { return absl::Substitute("{.UPB_PRIVATE(submsg) = &$0}", - MessageInitName(message_def)); + is_extension ? MessageInitName(message_def) + : MessagePtrName(message_def)); } if (auto enum_def = field.enum_subdef()) { @@ -345,17 +350,18 @@ void WriteMessage(upb::MessageDefPtr message, const DefPoolPair& pools, uint32_t index = f->UPB_PRIVATE(submsg_index); if (index != kUpb_NoSub) { const int f_number = upb_MiniTableField_Number(f); - auto pair = - subs.emplace(index, GetSub(message.FindFieldByNumber(f_number))); + upb::FieldDefPtr field = message.FindFieldByNumber(f_number); + auto pair = subs.emplace(index, GetSub(field, false)); ABSL_CHECK(pair.second); } } - // Write upb_MiniTableSub table for sub messages referenced from fields. + // Write upb_MiniTableSubInternal table for sub messages referenced from + // fields. if (!subs.empty()) { std::string submsgs_array_name = msg_name + "_submsgs"; submsgs_array_ref = "&" + submsgs_array_name + "[0]"; - output("static const upb_MiniTableSub $0[$1] = {\n", submsgs_array_name, - subs.size()); + output("static const upb_MiniTableSubInternal $0[$1] = {\n", + submsgs_array_name, subs.size()); int i = 0; for (const auto& pair : subs) { @@ -421,6 +427,8 @@ void WriteMessage(upb::MessageDefPtr message, const DefPoolPair& pools, output(" })\n"); } output("};\n\n"); + output("const upb_MiniTable* $0 = &$1;\n", MessagePtrName(message), + MessageInitName(message)); } void WriteEnum(upb::EnumDefPtr e, Output& output) { @@ -492,7 +500,7 @@ void WriteExtension(upb::FieldDefPtr ext, const DefPoolPair& pools, Output& output) { output("$0,\n", FieldInitializer(pools, ext)); output(" &$0,\n", MessageInitName(ext.containing_type())); - output(" $0,\n", GetSub(ext)); + output(" $0,\n", GetSub(ext, true)); } int WriteExtensions(const DefPoolPair& pools, upb::FileDefPtr file, @@ -571,6 +579,7 @@ void WriteMiniTableHeader(const DefPoolPair& pools, upb::FileDefPtr file, for (auto message : this_file_messages) { output("extern const upb_MiniTable $0;\n", MessageInitName(message)); + output("extern const upb_MiniTable* $0;\n", MessagePtrName(message)); } for (auto ext : this_file_exts) { output("extern const upb_MiniTableExtension $0;\n", ExtensionLayout(ext));