diff --git a/rust/cpp_kernel/cpp_api.h b/rust/cpp_kernel/cpp_api.h index 1db5445956..bc80df8222 100644 --- a/rust/cpp_kernel/cpp_api.h +++ b/rust/cpp_kernel/cpp_api.h @@ -34,8 +34,6 @@ #define GOOGLE_PROTOBUF_RUST_CPP_KERNEL_CPP_H__ #include -#include -#include #include "google/protobuf/message.h" @@ -73,6 +71,16 @@ inline SerializedData SerializeMsg(const google::protobuf::Message* msg) { return SerializedData(static_cast(bytes), len); } +// Represents an ABI-stable version of &[u8]/string_view (borrowed slice of +// bytes) for FFI use only. +struct PtrAndLen { + /// Borrows the memory. + const char* ptr; + size_t len; + + PtrAndLen(const char* ptr, size_t len) : ptr(ptr), len(len) {} +}; + } // namespace rust_internal } // namespace protobuf } // namespace google diff --git a/rust/shared.rs b/rust/shared.rs index 9e19e3a535..ed69760fdb 100644 --- a/rust/shared.rs +++ b/rust/shared.rs @@ -52,3 +52,12 @@ impl fmt::Display for ParseError { write!(f, "Couldn't deserialize given bytes into a proto") } } + +/// Represents an ABI-stable version of &[u8]/string_view (a borrowed slice of +/// bytes) for FFI use only. +#[repr(C)] +pub struct PtrAndLen { + /// Borrows the memory. + pub ptr: *const u8, + pub len: usize, +} diff --git a/rust/test/cpp/accessors_test.rs b/rust/test/cpp/accessors_test.rs index c43f49a410..91384e3391 100644 --- a/rust/test/cpp/accessors_test.rs +++ b/rust/test/cpp/accessors_test.rs @@ -53,3 +53,15 @@ fn test_optional_bool_accessors() { msg.optional_bool_set(None); assert_eq!(msg.optional_bool(), None); } + +#[test] +fn test_optional_bytes_accessors() { + let mut msg = unittest_proto::TestAllTypes::new(); + assert_eq!(msg.optional_bytes(), None); + + msg.optional_bytes_set(Some(b"accessors_test")); + assert_eq!(msg.optional_bytes().unwrap(), b"accessors_test"); + + msg.optional_bytes_set(None); + assert_eq!(msg.optional_bytes(), None); +} diff --git a/src/google/protobuf/compiler/rust/generator.cc b/src/google/protobuf/compiler/rust/generator.cc index e5bc5e072d..6e8733a3fc 100644 --- a/src/google/protobuf/compiler/rust/generator.cc +++ b/src/google/protobuf/compiler/rust/generator.cc @@ -45,6 +45,7 @@ #include "google/protobuf/compiler/cpp/names.h" #include "google/protobuf/compiler/rust/upb_kernel.h" #include "google/protobuf/descriptor.h" +#include "google/protobuf/descriptor.pb.h" #include "google/protobuf/io/printer.h" namespace google { @@ -100,7 +101,7 @@ std::string GetFileExtensionForKernel(Kernel kernel) { case Kernel::kCpp: return ".c.pb.rs"; } - ABSL_LOG(FATAL) << "Unknown kernel type: "; + ABSL_LOG(FATAL) << "Unknown kernel type: " << static_cast(kernel); return ""; } @@ -119,16 +120,22 @@ std::string GetAccessorThunkName( bool IsSupportedFieldType(const FieldDescriptor* field) { return !field->is_repeated() && - (field->cpp_type() == FieldDescriptor::CPPTYPE_BOOL || - field->cpp_type() == FieldDescriptor::CPPTYPE_INT64); + // We do not support [ctype=FOO] (used to set the field type in C++ to + // cord or string_piece) in V0 API. + !field->options().has_ctype() && + (field->type() == FieldDescriptor::TYPE_BOOL || + field->type() == FieldDescriptor::TYPE_INT64 || + field->type() == FieldDescriptor::TYPE_BYTES); } -std::string PrimitiveRsTypeName(const FieldDescriptor* field) { - switch (field->cpp_type()) { - case FieldDescriptor::CPPTYPE_INT64: - return "i64"; - case FieldDescriptor::CPPTYPE_BOOL: +absl::string_view PrimitiveRsTypeName(const FieldDescriptor* field) { + switch (field->type()) { + case FieldDescriptor::TYPE_BOOL: return "bool"; + case FieldDescriptor::TYPE_INT64: + return "i64"; + case FieldDescriptor::TYPE_BYTES: + return "&[u8]"; default: break; } @@ -136,6 +143,26 @@ std::string PrimitiveRsTypeName(const FieldDescriptor* field) { return ""; } +void EmitGetterExpr(const FieldDescriptor* field, google::protobuf::io::Printer& p, + absl::string_view underscore_delimited_full_name) { + std::string thunk_name = + GetAccessorThunkName(field, "get", underscore_delimited_full_name); + switch (field->type()) { + case FieldDescriptor::TYPE_BYTES: + p.Emit({{"getter_thunk_name", thunk_name}}, + R"rs( + let val = unsafe { $getter_thunk_name$(self.msg) }; + Some(unsafe { ::__std::slice::from_raw_parts(val.ptr, val.len) }) + )rs"); + return; + default: + p.Emit({{"getter_thunk_name", thunk_name}}, + R"rs( + Some(unsafe { $getter_thunk_name$(self.msg) }) + )rs"); + } +} + void GenerateAccessorFns(const Descriptor* msg, google::protobuf::io::Printer& p, absl::string_view underscore_delimited_full_name) { for (int i = 0; i < msg->field_count(); ++i) { @@ -151,10 +178,23 @@ void GenerateAccessorFns(const Descriptor* msg, google::protobuf::io::Printer& p GetAccessorThunkName(field, "has", underscore_delimited_full_name)}, {"getter_thunk_name", - GetAccessorThunkName(field, "", underscore_delimited_full_name)}, + GetAccessorThunkName(field, "get", + underscore_delimited_full_name)}, + {"getter_expr", + [&] { EmitGetterExpr(field, p, underscore_delimited_full_name); }}, {"setter_thunk_name", GetAccessorThunkName(field, "set", underscore_delimited_full_name)}, + {"setter_args", + [&] { + switch (field->type()) { + case FieldDescriptor::TYPE_BYTES: + p.Emit("val.as_ptr(), val.len()"); + return; + default: + p.Emit("val"); + } + }}, {"clearer_thunk_name", GetAccessorThunkName(field, "clear", underscore_delimited_full_name)}, @@ -164,11 +204,11 @@ void GenerateAccessorFns(const Descriptor* msg, google::protobuf::io::Printer& p if !unsafe { $hazzer_thunk_name$(self.msg) } { return None; } - Some(unsafe { $getter_thunk_name$(self.msg) }) + $getter_expr$ } pub fn $field_name$_set(&mut self, val: Option<$FieldType$>) { match val { - Some(val) => unsafe { $setter_thunk_name$(self.msg, val) }, + Some(val) => unsafe { $setter_thunk_name$(self.msg, $setter_args$) }, None => unsafe { $clearer_thunk_name$(self.msg) }, } } @@ -184,25 +224,47 @@ void GenerateAccessorThunkRsDeclarations( if (!IsSupportedFieldType(field)) { continue; } + absl::string_view type_name = PrimitiveRsTypeName(field); p.Emit( { - {"FieldType", PrimitiveRsTypeName(field)}, + {"FieldType", type_name}, + {"GetterReturnType", + [&] { + switch (field->type()) { + case FieldDescriptor::TYPE_BYTES: + p.Emit("::__pb::PtrAndLen"); + return; + default: + p.Emit(type_name); + } + }}, {"hazzer_thunk_name", GetAccessorThunkName(field, "has", underscore_delimited_full_name)}, {"getter_thunk_name", - GetAccessorThunkName(field, "", underscore_delimited_full_name)}, + GetAccessorThunkName(field, "get", + underscore_delimited_full_name)}, {"setter_thunk_name", GetAccessorThunkName(field, "set", underscore_delimited_full_name)}, + {"setter_params", + [&] { + switch (field->type()) { + case FieldDescriptor::TYPE_BYTES: + p.Emit("val: *const u8, len: usize"); + return; + default: + p.Emit({{"type_name", type_name}}, "val: $type_name$"); + } + }}, {"clearer_thunk_name", GetAccessorThunkName(field, "clear", underscore_delimited_full_name)}, }, R"rs( fn $hazzer_thunk_name$(raw_msg: ::__std::ptr::NonNull) -> bool; - fn $getter_thunk_name$(raw_msg: ::__std::ptr::NonNull) -> $FieldType$; - fn $setter_thunk_name$(raw_msg: ::__std::ptr::NonNull, val: $FieldType$); + fn $getter_thunk_name$(raw_msg: ::__std::ptr::NonNull) -> $GetterReturnType$;; + fn $setter_thunk_name$(raw_msg: ::__std::ptr::NonNull, $setter_params$); fn $clearer_thunk_name$(raw_msg: ::__std::ptr::NonNull); )rs"); } @@ -216,16 +278,60 @@ void GenerateAccessorThunksCcDefinitions( if (!IsSupportedFieldType(field)) { continue; } + const char* type_name = cpp::PrimitiveTypeName(field->cpp_type()); p.Emit( {{"field_name", field->name()}, - {"FieldType", cpp::PrimitiveTypeName(field->cpp_type())}, + {"FieldType", type_name}, + {"GetterReturnType", + [&] { + switch (field->type()) { + case FieldDescriptor::TYPE_BYTES: + p.Emit("::google::protobuf::rust_internal::PtrAndLen"); + return; + default: + p.Emit(type_name); + } + }}, {"namespace", cpp::Namespace(msg)}, {"hazzer_thunk_name", GetAccessorThunkName(field, "has", underscore_delimited_full_name)}, {"getter_thunk_name", - GetAccessorThunkName(field, "", underscore_delimited_full_name)}, + GetAccessorThunkName(field, "get", underscore_delimited_full_name)}, + {"getter_body", + [&] { + switch (field->type()) { + case FieldDescriptor::TYPE_BYTES: + p.Emit({{"field_name", field->name()}}, R"cc( + absl::string_view val = msg->$field_name$(); + return google::protobuf::rust_internal::PtrAndLen(val.data(), val.size()); + )cc"); + return; + default: + p.Emit(R"cc(return msg->$field_name$();)cc"); + } + }}, {"setter_thunk_name", GetAccessorThunkName(field, "set", underscore_delimited_full_name)}, + {"setter_params", + [&] { + switch (field->type()) { + case FieldDescriptor::TYPE_BYTES: + p.Emit("const char* ptr, size_t size"); + return; + default: + p.Emit({{"type_name", type_name}}, "$type_name$ val"); + } + }}, + {"setter_args", + [&] { + switch (field->type()) { + case FieldDescriptor::TYPE_BYTES: + p.Emit("absl::string_view(ptr, size)"); + return; + default: + p.Emit("val"); + } + }}, {"clearer_thunk_name", GetAccessorThunkName(field, "clear", underscore_delimited_full_name)}}, @@ -234,11 +340,11 @@ void GenerateAccessorThunksCcDefinitions( bool $hazzer_thunk_name$($namespace$::$Msg$* msg) { return msg->has_$field_name$(); } - $FieldType$ $getter_thunk_name$($namespace$::$Msg$* msg) { - return msg->$field_name$(); + $GetterReturnType$ $getter_thunk_name$($namespace$::$Msg$* msg) { + $getter_body$ } - void $setter_thunk_name$($namespace$::$Msg$* msg, $FieldType$ val) { - msg->set_$field_name$(val); + void $setter_thunk_name$($namespace$::$Msg$* msg, $setter_params$) { + msg->set_$field_name$($setter_args$); } void $clearer_thunk_name$($namespace$::$Msg$* msg) { msg->clear_$field_name$(); @@ -353,7 +459,7 @@ std::string GetKernelRustName(Kernel kernel) { case Kernel::kCpp: return "cpp"; } - ABSL_LOG(FATAL) << "Unknown kernel type: "; + ABSL_LOG(FATAL) << "Unknown kernel type: " << static_cast(kernel); return ""; }