From c6f6a3291e32fefb65de89b7e56e359034310ee2 Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Wed, 3 Apr 2024 07:03:14 -0700 Subject: [PATCH] Add Length-Delimited Encode and Decode functions to upb. PiperOrigin-RevId: 621510731 --- upb/message/test.cc | 22 ++++++-- upb/test/BUILD | 19 +++++++ upb/test/length_delimited_test.cc | 86 +++++++++++++++++++++++++++++++ upb/wire/decode.c | 39 +++++++++++++- upb/wire/decode.h | 10 +++- upb/wire/encode.c | 32 +++++++++--- upb/wire/encode.h | 7 +++ 7 files changed, 201 insertions(+), 14 deletions(-) create mode 100644 upb/test/length_delimited_test.cc diff --git a/upb/message/test.cc b/upb/message/test.cc index 433632db14..f80e808768 100644 --- a/upb/message/test.cc +++ b/upb/message/test.cc @@ -588,7 +588,7 @@ TEST(MessageTest, Freeze) { // // static void DecodeEncodeArbitrarySchemaAndPayload( // const upb::fuzz::MiniTableFuzzInput& input, std::string_view proto_payload, -// int decode_options, int encode_options) { +// int decode_options, int encode_options, bool length_delimited = false) { // // Lexan does not have setenv // #ifndef _MSC_VER // setenv("FUZZTEST_STACK_LIMIT", "262144", 1); @@ -605,11 +605,25 @@ TEST(MessageTest, Freeze) { // upb::fuzz::BuildMiniTable(input, &exts, arena.ptr()); // if (!mini_table) return; // upb_Message* msg = upb_Message_New(mini_table, arena.ptr()); -// upb_Decode(proto_payload.data(), proto_payload.size(), msg, mini_table, exts, -// decode_options, arena.ptr()); +// if (length_delimited) { +// size_t num_bytes_read = 0; +// upb_DecodeStatus status = upb_DecodeLengthDelimited( +// proto_payload.data(), proto_payload.size(), msg, &num_bytes_read, +// mini_table, exts, decode_options, arena.ptr()); +// ASSERT_TRUE(status != kUpb_DecodeStatus_Ok || +// num_bytes_read <= proto_payload.size()); +// } else { +// upb_Decode(proto_payload.data(), proto_payload.size(), msg, mini_table, +// exts, decode_options, arena.ptr()); +// } // char* ptr; // size_t size; -// upb_Encode(msg, mini_table, encode_options, arena.ptr(), &ptr, &size); +// if (length_delimited) { +// upb_EncodeLengthDelimited(msg, mini_table, encode_options, arena.ptr(), +// &ptr, &size); +// } else { +// upb_Encode(msg, mini_table, encode_options, arena.ptr(), &ptr, &size); +// } // } // FUZZ_TEST(FuzzTest, DecodeEncodeArbitrarySchemaAndPayload); // diff --git a/upb/test/BUILD b/upb/test/BUILD index 24ef674310..094194ae21 100644 --- a/upb/test/BUILD +++ b/upb/test/BUILD @@ -201,6 +201,25 @@ cc_test( ], ) +cc_test( + name = "length_delimited_test", + srcs = ["length_delimited_test.cc"], + copts = UPB_DEFAULT_CPPOPTS, + deps = [ + ":test_messages_proto2_upb_minitable", + ":test_messages_proto2_upb_proto", + "//upb:base", + "//upb:mem", + "//upb:message", + "//upb:message_compare", + "//upb:mini_table", + "//upb:wire", + "//upb/mem:internal", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "test_cpp", srcs = ["test_cpp.cc"], diff --git a/upb/test/length_delimited_test.cc b/upb/test/length_delimited_test.cc new file mode 100644 index 0000000000..0c2f38a7f1 --- /dev/null +++ b/upb/test/length_delimited_test.cc @@ -0,0 +1,86 @@ + +#include +#include +#include +#include + +#include +#include "google/protobuf/test_messages_proto2.upb.h" +#include "google/protobuf/test_messages_proto2.upb_minitable.h" +#include "upb/base/string_view.h" +#include "upb/base/upcast.h" +#include "upb/mem/arena.h" +#include "upb/message/compare.h" +#include "upb/mini_table/message.h" +#include "upb/wire/decode.h" +#include "upb/wire/encode.h" + +namespace { + +static const upb_MiniTable* kTestMiniTable = + &protobuf_0test_0messages__proto2__TestAllTypesProto2_msg_init; + +static void TestEncodeDecodeRoundTrip( + upb_Arena* arena, + std::vector msgs) { + // Encode all of the messages and put their serializations contiguously. + std::string s; + for (auto msg : msgs) { + char* buf; + size_t size; + ASSERT_TRUE(upb_EncodeLengthDelimited(UPB_UPCAST(msg), kTestMiniTable, 0, + arena, &buf, + &size) == kUpb_EncodeStatus_Ok); + ASSERT_GT(size, 0); // Even empty messages are 1 byte in this encoding. + s.append(std::string(buf, size)); + } + + // Now decode all of the messages contained in the contiguous block. + std::vector decoded; + while (!s.empty()) { + protobuf_test_messages_proto2_TestAllTypesProto2* msg = + protobuf_test_messages_proto2_TestAllTypesProto2_new(arena); + size_t num_bytes_read; + ASSERT_TRUE(upb_DecodeLengthDelimited( + s.data(), s.length(), UPB_UPCAST(msg), &num_bytes_read, + kTestMiniTable, nullptr, 0, arena) == kUpb_DecodeStatus_Ok); + ASSERT_GT(num_bytes_read, 0); + decoded.push_back(msg); + s = s.substr(num_bytes_read); + } + + // Make sure that the values round tripped correctly. + ASSERT_EQ(msgs.size(), decoded.size()); + for (size_t i = 0; i < msgs.size(); ++i) { + ASSERT_TRUE(upb_Message_IsEqual(UPB_UPCAST(msgs[i]), UPB_UPCAST(decoded[i]), + kTestMiniTable, 0)); + } +} + +TEST(LengthDelimitedTest, OneEmptyMessage) { + upb_Arena* arena = upb_Arena_New(); + protobuf_test_messages_proto2_TestAllTypesProto2* msg = + protobuf_test_messages_proto2_TestAllTypesProto2_new(arena); + TestEncodeDecodeRoundTrip(arena, {msg}); + upb_Arena_Free(arena); +} + +TEST(LengthDelimitedTest, AFewMessages) { + upb_Arena* arena = upb_Arena_New(); + protobuf_test_messages_proto2_TestAllTypesProto2* a = + protobuf_test_messages_proto2_TestAllTypesProto2_new(arena); + protobuf_test_messages_proto2_TestAllTypesProto2* b = + protobuf_test_messages_proto2_TestAllTypesProto2_new(arena); + protobuf_test_messages_proto2_TestAllTypesProto2* c = + protobuf_test_messages_proto2_TestAllTypesProto2_new(arena); + + protobuf_test_messages_proto2_TestAllTypesProto2_set_optional_bool(a, true); + protobuf_test_messages_proto2_TestAllTypesProto2_set_optional_int32(b, 1); + protobuf_test_messages_proto2_TestAllTypesProto2_set_oneof_string( + c, upb_StringView_FromString("string")); + + TestEncodeDecodeRoundTrip(arena, {a, b, c}); + upb_Arena_Free(arena); +} + +} // namespace diff --git a/upb/wire/decode.c b/upb/wire/decode.c index 19ba506a41..4406c03c7a 100644 --- a/upb/wire/decode.c +++ b/upb/wire/decode.c @@ -1366,7 +1366,7 @@ static upb_DecodeStatus upb_Decoder_Decode(upb_Decoder* const decoder, } upb_DecodeStatus upb_Decode(const char* buf, size_t size, upb_Message* msg, - const upb_MiniTable* m, + const upb_MiniTable* mt, const upb_ExtensionRegistry* extreg, int options, upb_Arena* arena) { UPB_ASSERT(!upb_Message_IsFrozen(msg)); @@ -1391,7 +1391,42 @@ upb_DecodeStatus upb_Decode(const char* buf, size_t size, upb_Message* msg, // (particularly parent_or_count). UPB_PRIVATE(_upb_Arena_SwapIn)(&decoder.arena, arena); - return upb_Decoder_Decode(&decoder, buf, msg, m, arena); + return upb_Decoder_Decode(&decoder, buf, msg, mt, arena); +} + +upb_DecodeStatus upb_DecodeLengthDelimited(const char* buf, size_t size, + upb_Message* msg, + size_t* num_bytes_read, + const upb_MiniTable* mt, + const upb_ExtensionRegistry* extreg, + int options, upb_Arena* arena) { + // To avoid needing to make a Decoder just to decode the initial length, + // hand-decode the leading varint for the message length here. + uint64_t msg_len = 0; + for (size_t i = 0;; ++i) { + if (i >= size || i > 9) { + return kUpb_DecodeStatus_Malformed; + } + uint64_t b = *buf; + buf++; + msg_len += (b & 0x7f) << (i * 7); + if ((b & 0x80) == 0) { + *num_bytes_read = i + 1 + msg_len; + break; + } + } + + // If the total number of bytes we would read (= the bytes from the varint + // plus however many bytes that varint says we should read) is larger then the + // input buffer then error as malformed. + if (*num_bytes_read > size) { + return kUpb_DecodeStatus_Malformed; + } + if (msg_len > INT32_MAX) { + return kUpb_DecodeStatus_Malformed; + } + + return upb_Decode(buf, msg_len, msg, mt, extreg, options, arena); } #undef OP_FIXPCK_LG2 diff --git a/upb/wire/decode.h b/upb/wire/decode.h index 68bf63f356..afd4bbb5e3 100644 --- a/upb/wire/decode.h +++ b/upb/wire/decode.h @@ -129,10 +129,18 @@ typedef enum { } upb_DecodeStatus; UPB_API upb_DecodeStatus upb_Decode(const char* buf, size_t size, - upb_Message* msg, const upb_MiniTable* l, + upb_Message* msg, const upb_MiniTable* mt, const upb_ExtensionRegistry* extreg, int options, upb_Arena* arena); +// Same as upb_Decode but with a varint-encoded length prepended. +// On success 'num_bytes_read' will be set to the how many bytes were read, +// on failure the contents of num_bytes_read is undefined. +UPB_API upb_DecodeStatus upb_DecodeLengthDelimited( + const char* buf, size_t size, upb_Message* msg, size_t* num_bytes_read, + const upb_MiniTable* mt, const upb_ExtensionRegistry* extreg, int options, + upb_Arena* arena); + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/upb/wire/encode.c b/upb/wire/encode.c index 671de1afdb..7a35afbf3e 100644 --- a/upb/wire/encode.c +++ b/upb/wire/encode.c @@ -607,14 +607,18 @@ static void encode_message(upb_encstate* e, const upb_Message* msg, static upb_EncodeStatus upb_Encoder_Encode(upb_encstate* const encoder, const upb_Message* const msg, const upb_MiniTable* const l, - char** const buf, - size_t* const size) { + char** const buf, size_t* const size, + bool prepend_len) { // Unfortunately we must continue to perform hackery here because there are // code paths which blindly copy the returned pointer without bothering to // check for errors until much later (b/235839510). So we still set *buf to // NULL on error and we still set it to non-NULL on a successful empty result. if (UPB_SETJMP(encoder->err) == 0) { - encode_message(encoder, msg, l, size); + size_t encoded_msg_size; + encode_message(encoder, msg, l, &encoded_msg_size); + if (prepend_len) { + encode_varint(encoder, encoded_msg_size); + } *size = encoder->limit - encoder->ptr; if (*size == 0) { static char ch; @@ -633,9 +637,10 @@ static upb_EncodeStatus upb_Encoder_Encode(upb_encstate* const encoder, return encoder->status; } -upb_EncodeStatus upb_Encode(const upb_Message* msg, const upb_MiniTable* l, - int options, upb_Arena* arena, char** buf, - size_t* size) { +static upb_EncodeStatus _upb_Encode(const upb_Message* msg, + const upb_MiniTable* l, int options, + upb_Arena* arena, char** buf, size_t* size, + bool prepend_len) { upb_encstate e; unsigned depth = (unsigned)options >> 16; @@ -648,5 +653,18 @@ upb_EncodeStatus upb_Encode(const upb_Message* msg, const upb_MiniTable* l, e.options = options; _upb_mapsorter_init(&e.sorter); - return upb_Encoder_Encode(&e, msg, l, buf, size); + return upb_Encoder_Encode(&e, msg, l, buf, size, prepend_len); +} + +upb_EncodeStatus upb_Encode(const upb_Message* msg, const upb_MiniTable* l, + int options, upb_Arena* arena, char** buf, + size_t* size) { + return _upb_Encode(msg, l, options, arena, buf, size, false); +} + +upb_EncodeStatus upb_EncodeLengthDelimited(const upb_Message* msg, + const upb_MiniTable* l, int options, + upb_Arena* arena, char** buf, + size_t* size) { + return _upb_Encode(msg, l, options, arena, buf, size, true); } diff --git a/upb/wire/encode.h b/upb/wire/encode.h index fed261de51..d013041415 100644 --- a/upb/wire/encode.h +++ b/upb/wire/encode.h @@ -68,6 +68,13 @@ UPB_API upb_EncodeStatus upb_Encode(const upb_Message* msg, const upb_MiniTable* l, int options, upb_Arena* arena, char** buf, size_t* size); +// Encodes the message prepended by a varint of the serialized length. +UPB_API upb_EncodeStatus upb_EncodeLengthDelimited(const upb_Message* msg, + const upb_MiniTable* l, + int options, + upb_Arena* arena, char** buf, + size_t* size); + #ifdef __cplusplus } /* extern "C" */ #endif