diff --git a/rust/cpp_kernel/BUILD b/rust/cpp_kernel/BUILD index 2438bfb79d..ef04dd41d4 100644 --- a/rust/cpp_kernel/BUILD +++ b/rust/cpp_kernel/BUILD @@ -14,6 +14,7 @@ cc_library( ":rust_alloc_for_cpp_api", # buildcleaner: keep "//src/google/protobuf", "//src/google/protobuf:protobuf_lite", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", ], ) diff --git a/rust/cpp_kernel/cpp_api.h b/rust/cpp_kernel/cpp_api.h index 44f299feb5..8ee8ebc77d 100644 --- a/rust/cpp_kernel/cpp_api.h +++ b/rust/cpp_kernel/cpp_api.h @@ -10,10 +10,13 @@ #ifndef GOOGLE_PROTOBUF_RUST_CPP_KERNEL_CPP_H__ #define GOOGLE_PROTOBUF_RUST_CPP_KERNEL_CPP_H__ +#include #include +#include #include #include +#include "absl/log/absl_check.h" #include "absl/log/absl_log.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" @@ -31,11 +34,11 @@ namespace rust_internal { // * The data were allocated using the Rust allocator. // extern "C" struct SerializedData { - // Owns the memory. - const char* data; + // Owns the memory, must be freed by Rust. + const uint8_t* data; size_t len; - SerializedData(const char* data, size_t len) : data(data), len(len) {} + SerializedData(const uint8_t* data, size_t len) : data(data), len(len) {} }; // Allocates memory using the current Rust global allocator. @@ -44,15 +47,21 @@ extern "C" struct SerializedData { extern "C" void* __pb_rust_alloc(size_t size, size_t align); inline bool SerializeMsg(const google::protobuf::MessageLite* msg, SerializedData* out) { + ABSL_DCHECK(msg->IsInitialized()); size_t len = msg->ByteSizeLong(); - void* bytes = __pb_rust_alloc(len, alignof(char)); + if (len > INT_MAX) { + ABSL_LOG(ERROR) << msg->GetTypeName() + << " exceeded maximum protobuf size of 2GB: " << len; + return false; + } + uint8_t* bytes = static_cast(__pb_rust_alloc(len, alignof(char))); if (bytes == nullptr) { ABSL_LOG(FATAL) << "Rust allocator failed to allocate memory."; } - if (!msg->SerializeToArray(bytes, static_cast(len))) { + if (!msg->SerializeWithCachedSizesToArray(bytes)) { return false; } - *out = SerializedData(static_cast(bytes), len); + *out = SerializedData(bytes, len); return true; }