diff --git a/upb/fuzz_test_util.cc b/upb/fuzz_test_util.cc index 3ca4e89f31..051444dcb3 100644 --- a/upb/fuzz_test_util.cc +++ b/upb/fuzz_test_util.cc @@ -85,16 +85,9 @@ void Builder::BuildMessages() { mini_tables_.reserve(input_->mini_descriptors.size()); for (const auto& d : input_->mini_descriptors) { upb_MiniTable* table; - if (d == "\n") { - // We special-case this input string, which is not a valid - // mini-descriptor, to mean message set. - table = - upb_MiniTable_BuildMessageSet(kUpb_MiniTablePlatform_Native, arena_); - } else { - table = - upb_MiniTable_Build(d.data(), d.size(), kUpb_MiniTablePlatform_Native, - arena_, status.ptr()); - } + table = + upb_MiniTable_Build(d.data(), d.size(), kUpb_MiniTablePlatform_Native, + arena_, status.ptr()); if (table) mini_tables_.push_back(table); } } diff --git a/upb/mini_table.c b/upb/mini_table.c index 1217299da8..99bb14487c 100644 --- a/upb/mini_table.c +++ b/upb/mini_table.c @@ -87,6 +87,7 @@ enum { kUpb_EncodedVersion_ExtensionV1 = '#', kUpb_EncodedVersion_MapV1 = '%', kUpb_EncodedVersion_MessageV1 = '$', + kUpb_EncodedVersion_MessageSetV1 = '&', }; char upb_ToBase92(int8_t ch) { @@ -231,6 +232,11 @@ char* upb_MtDataEncoder_EncodeMap(upb_MtDataEncoder* e, char* ptr, return upb_MtDataEncoder_PutField(e, ptr, value_type, 2, value_mod); } +char* upb_MtDataEncoder_EncodeMessageSet(upb_MtDataEncoder* e, char* ptr) { + (void)upb_MtDataEncoder_GetInternal(e, ptr); + return upb_MtDataEncoder_PutRaw(e, ptr, kUpb_EncodedVersion_MessageSetV1); +} + char* upb_MtDataEncoder_StartMessage(upb_MtDataEncoder* e, char* ptr, uint64_t msg_mod) { upb_MtDataEncoderInternal* in = upb_MtDataEncoder_GetInternal(e, ptr); @@ -1091,7 +1097,7 @@ static void upb_MiniTable_BuildMapEntry(upb_MtDecoder* d, static void upb_MtDecoder_ParseMap(upb_MtDecoder* d, const char* data, size_t len) { if (len < 2) { - upb_MtDecoder_ErrorFormat(d, "Invalid map encoding length: %zu", len); + upb_MtDecoder_ErrorFormat(d, "Invalid map encode length: %zu", len); UPB_UNREACHABLE(); } const upb_EncodedType e0 = upb_FromBase92(data[0]); @@ -1125,6 +1131,22 @@ static void upb_MtDecoder_ParseMap(upb_MtDecoder* d, const char* data, upb_MiniTable_BuildMapEntry(d, key_type, val_type, value_is_proto3_enum); } +static void upb_MtDecoder_ParseMessageSet(upb_MtDecoder* d, const char* data, + size_t len) { + if (len > 0) { + upb_MtDecoder_ErrorFormat(d, "Invalid message set encode length: %zu", len); + UPB_UNREACHABLE(); + } + + upb_MiniTable* ret = d->table; + ret->size = 0; + ret->field_count = 0; + ret->ext = kUpb_ExtMode_IsMessageSet; + ret->dense_below = 0; + ret->table_mask = -1; + ret->required_count = 0; +} + upb_MiniTable* upb_MiniTable_BuildWithBuf(const char* data, size_t len, upb_MiniTablePlatform platform, upb_Arena* arena, void** buf, @@ -1173,6 +1195,10 @@ upb_MiniTable* upb_MiniTable_BuildWithBuf(const char* data, size_t len, upb_MtDecoder_AssignOffsets(&decoder); break; + case kUpb_EncodedVersion_MessageSetV1: + upb_MtDecoder_ParseMessageSet(&decoder, data, len); + break; + default: upb_MtDecoder_ErrorFormat(&decoder, "Invalid message version: %c", vers); UPB_UNREACHABLE(); @@ -1184,20 +1210,6 @@ done: return decoder.table; } -upb_MiniTable* upb_MiniTable_BuildMessageSet(upb_MiniTablePlatform platform, - upb_Arena* arena) { - upb_MiniTable* ret = upb_Arena_Malloc(arena, sizeof(*ret)); - if (!ret) return NULL; - - ret->size = 0; - ret->field_count = 0; - ret->ext = kUpb_ExtMode_IsMessageSet; - ret->dense_below = 0; - ret->table_mask = -1; - ret->required_count = 0; - return ret; -} - static size_t upb_MiniTable_EnumSize(size_t count) { return sizeof(upb_MiniTable_Enum) + count * sizeof(uint32_t); } diff --git a/upb/mini_table.h b/upb/mini_table.h index 27a7471a2f..519af75eb8 100644 --- a/upb/mini_table.h +++ b/upb/mini_table.h @@ -131,6 +131,9 @@ char* upb_MtDataEncoder_EncodeMap(upb_MtDataEncoder* e, char* ptr, upb_FieldType key_type, upb_FieldType value_type, uint64_t value_mod); +// Encodes an entire mini descriptor for a message set. +char* upb_MtDataEncoder_EncodeMessageSet(upb_MtDataEncoder* e, char* ptr); + /** upb_MiniTable *************************************************************/ typedef enum { @@ -169,10 +172,6 @@ const char* upb_MiniTable_BuildExtension(const char* data, size_t len, upb_MiniTable_Sub sub, upb_Status* status); -// Special-case functions for MessageSet layout and map entries. -upb_MiniTable* upb_MiniTable_BuildMessageSet(upb_MiniTablePlatform platform, - upb_Arena* arena); - upb_MiniTable_Enum* upb_MiniTable_BuildEnum(const char* data, size_t len, upb_Arena* arena, upb_Status* status); diff --git a/upb/mini_table.hpp b/upb/mini_table.hpp index c3bfb8bc17..14cb4d87a2 100644 --- a/upb/mini_table.hpp +++ b/upb/mini_table.hpp @@ -97,6 +97,12 @@ class MtDataEncoder { }); } + bool EncodeMessageSet() { + return appender_([=](char* buf) { + return upb_MtDataEncoder_EncodeMessageSet(&encoder_, buf); + }); + } + const std::string& data() const { return appender_.data(); } private: diff --git a/upb/reflection/message_def.c b/upb/reflection/message_def.c index 7e0e9eb3b7..7e674a1ec9 100644 --- a/upb/reflection/message_def.c +++ b/upb/reflection/message_def.c @@ -131,8 +131,7 @@ bool _upb_MessageDef_IsValidExtensionNumber(const upb_MessageDef* m, int n) { return false; } -const google_protobuf_MessageOptions* upb_MessageDef_Options( - const upb_MessageDef* m) { +const google_protobuf_MessageOptions* upb_MessageDef_Options(const upb_MessageDef* m) { return m->opts; } @@ -306,24 +305,15 @@ const upb_OneofDef* upb_MessageDef_FindOneofByName(const upb_MessageDef* m, } bool upb_MessageDef_IsMapEntry(const upb_MessageDef* m) { - return google_protobuf_MessageOptions_map_entry(upb_MessageDef_Options(m)); + return google_protobuf_MessageOptions_map_entry(m->opts); } bool upb_MessageDef_IsMessageSet(const upb_MessageDef* m) { - return google_protobuf_MessageOptions_message_set_wire_format( - upb_MessageDef_Options(m)); + return google_protobuf_MessageOptions_message_set_wire_format(m->opts); } static upb_MiniTable* _upb_MessageDef_MakeMiniTable(upb_DefBuilder* ctx, const upb_MessageDef* m) { - if (google_protobuf_MessageOptions_message_set_wire_format(m->opts)) { - if (m->field_count > 0) { - _upb_DefBuilder_Errf(ctx, "invalid message set (%s)", m->full_name); - } - return upb_MiniTable_BuildMessageSet(kUpb_MiniTablePlatform_Native, - ctx->arena); - } - upb_StringView desc; bool ok = upb_MessageDef_MiniDescriptorEncode(m, ctx->tmp_arena, &desc); if (!ok) _upb_DefBuilder_OomErr(ctx); @@ -479,7 +469,6 @@ static bool _upb_MessageDef_EncodeMap(upb_DescState* s, const upb_MessageDef* m, ? kUpb_FieldModifier_IsClosedEnum : 0; - if (!_upb_DescState_Grow(s, a)) return false; s->ptr = upb_MtDataEncoder_EncodeMap(&s->e, s->ptr, key_type, val_type, val_mod); return true; @@ -494,7 +483,6 @@ static bool _upb_MessageDef_EncodeMessage(upb_DescState* s, if (!sorted) return false; } - if (!_upb_DescState_Grow(s, a)) return false; s->ptr = upb_MtDataEncoder_StartMessage(&s->e, s->ptr, _upb_MessageDef_Modifiers(m)); @@ -525,13 +513,25 @@ static bool _upb_MessageDef_EncodeMessage(upb_DescState* s, return true; } +static bool _upb_MessageDef_EncodeMessageSet(upb_DescState* s, + const upb_MessageDef* m, + upb_Arena* a) { + s->ptr = upb_MtDataEncoder_EncodeMessageSet(&s->e, s->ptr); + + return true; +} + bool upb_MessageDef_MiniDescriptorEncode(const upb_MessageDef* m, upb_Arena* a, upb_StringView* out) { upb_DescState s; _upb_DescState_Init(&s); + if (!_upb_DescState_Grow(&s, a)) return false; + if (upb_MessageDef_IsMapEntry(m)) { if (!_upb_MessageDef_EncodeMap(&s, m, a)) return false; + } else if (google_protobuf_MessageOptions_message_set_wire_format(m->opts)) { + if (!_upb_MessageDef_EncodeMessageSet(&s, m, a)) return false; } else { if (!_upb_MessageDef_EncodeMessage(&s, m, a)) return false; } @@ -596,6 +596,13 @@ static void create_msgdef(upb_DefBuilder* ctx, const char* prefix, m->fields = _upb_FieldDefs_New(ctx, n_field, fields, m->full_name, m, &m->is_sorted); + // Message Sets may not contain fields. + if (UPB_UNLIKELY(google_protobuf_MessageOptions_message_set_wire_format(m->opts))) { + if (UPB_UNLIKELY(n_field > 0)) { + _upb_DefBuilder_Errf(ctx, "invalid message set (%s)", m->full_name); + } + } + m->ext_range_count = n_ext_range; m->ext_ranges = _upb_ExtensionRanges_New(ctx, n_ext_range, ext_ranges, m); diff --git a/upbc/file_layout.cc b/upbc/file_layout.cc index f3c489c024..5f1ea660e7 100644 --- a/upbc/file_layout.cc +++ b/upbc/file_layout.cc @@ -285,7 +285,7 @@ void FilePlatformLayout::BuildExtensions(const protobuf::FileDescriptor* fd) { upb_MiniTable* FilePlatformLayout::MakeMiniTable( const protobuf::Descriptor* m) { if (m->options().message_set_wire_format()) { - return upb_MiniTable_BuildMessageSet(platform_, arena_.ptr()); + return MakeMessageSetMiniTable(m); } else if (m->options().map_entry()) { return MakeMapMiniTable(m); } else { @@ -317,6 +317,22 @@ upb_MiniTable* FilePlatformLayout::MakeMapMiniTable( return ret; } +upb_MiniTable* FilePlatformLayout::MakeMessageSetMiniTable( + const protobuf::Descriptor* m) { + upb::MtDataEncoder e; + e.EncodeMessageSet(); + + const absl::string_view str = e.data(); + upb::Status status; + upb_MiniTable* ret = upb_MiniTable_Build(str.data(), str.size(), platform_, + arena_.ptr(), status.ptr()); + if (!ret) { + fprintf(stderr, "Error building mini-table: %s\n", status.error_message()); + } + assert(ret); + return ret; +} + upb_MiniTable* FilePlatformLayout::MakeRegularMiniTable( const protobuf::Descriptor* m) { upb::MtDataEncoder e; diff --git a/upbc/file_layout.h b/upbc/file_layout.h index b448df6447..3dd809cba6 100644 --- a/upbc/file_layout.h +++ b/upbc/file_layout.h @@ -91,6 +91,7 @@ class FilePlatformLayout { void BuildExtensions(const protobuf::FileDescriptor* fd); upb_MiniTable* MakeMiniTable(const protobuf::Descriptor* m); upb_MiniTable* MakeMapMiniTable(const protobuf::Descriptor* m); + upb_MiniTable* MakeMessageSetMiniTable(const protobuf::Descriptor* m); upb_MiniTable* MakeRegularMiniTable(const protobuf::Descriptor* m); upb_MiniTable_Enum* MakeMiniTableEnum(const protobuf::EnumDescriptor* d); uint64_t GetMessageModifiers(const protobuf::Descriptor* m);