diff --git a/rust/test/cpp/BUILD b/rust/test/cpp/BUILD index 91782bf48b..86976b077c 100644 --- a/rust/test/cpp/BUILD +++ b/rust/test/cpp/BUILD @@ -9,3 +9,19 @@ # To do that use: # * `rust_cc_proto_library` instead of `rust_proto_library`. # * `//rust:protobuf_cpp` instead of `//rust:protobuf``. + +load("@rules_rust//rust:defs.bzl", "rust_test") + +rust_test( + name = "accessors_test", + srcs = ["accessors_test.rs"], + tags = [ + # TODO(b/270274576): Enable testing on arm once we have a Rust Arm toolchain. + "not_build:arm", + # TODO(b/243126140): Enable tsan once we support sanitizers with Rust. + "notsan", + # TODO(b/243126140): Enable msan once we support sanitizers with Rust. + "nomsan", + ], + deps = ["//rust/test:unittest_cc_rust_proto"], +) diff --git a/rust/test/cpp/accessors_test.rs b/rust/test/cpp/accessors_test.rs new file mode 100644 index 0000000000..c43f49a410 --- /dev/null +++ b/rust/test/cpp/accessors_test.rs @@ -0,0 +1,55 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2023 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/// Tests covering accessors for singular bool and int64 fields. + +#[test] +fn test_optional_int64_accessors() { + let mut msg = unittest_proto::TestAllTypes::new(); + assert_eq!(msg.optional_int64(), None); + + msg.optional_int64_set(Some(42)); + assert_eq!(msg.optional_int64(), Some(42)); + + msg.optional_int64_set(None); + assert_eq!(msg.optional_int64(), None); +} + +#[test] +fn test_optional_bool_accessors() { + let mut msg = unittest_proto::TestAllTypes::new(); + assert_eq!(msg.optional_bool(), None); + + msg.optional_bool_set(Some(true)); + assert_eq!(msg.optional_bool(), Some(true)); + + msg.optional_bool_set(None); + assert_eq!(msg.optional_bool(), None); +} diff --git a/rust/test/cpp/interop/main.rs b/rust/test/cpp/interop/main.rs index 618f98ebce..92174ff5b3 100644 --- a/rust/test/cpp/interop/main.rs +++ b/rust/test/cpp/interop/main.rs @@ -48,6 +48,13 @@ fn mutate_message_in_cpp() { assert_serializes_equally!(msg); } +#[test] +fn mutate_message_in_rust() { + let mut msg = unittest_proto::TestAllTypes::new(); + msg.optional_int64_set(Some(43)); + assert_serializes_equally!(msg); +} + #[test] fn deserialize_message_in_rust() { let serialized = unsafe { SerializeMutatedInstance() }; diff --git a/src/google/protobuf/compiler/rust/generator.cc b/src/google/protobuf/compiler/rust/generator.cc index fc100a5b81..e5bc5e072d 100644 --- a/src/google/protobuf/compiler/rust/generator.cc +++ b/src/google/protobuf/compiler/rust/generator.cc @@ -30,6 +30,7 @@ #include "google/protobuf/compiler/rust/generator.h" +#include #include #include #include @@ -38,7 +39,9 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "absl/types/optional.h" +#include "google/protobuf/compiler/cpp/helpers.h" #include "google/protobuf/compiler/cpp/names.h" #include "google/protobuf/compiler/rust/upb_kernel.h" #include "google/protobuf/descriptor.h" @@ -107,15 +110,166 @@ std::string GetUnderscoreDelimitedFullName(const Descriptor* msg) { return result; } +std::string GetAccessorThunkName( + const FieldDescriptor* field, absl::string_view op, + absl::string_view underscore_delimited_full_name) { + return absl::Substitute("__rust_proto_thunk__$0_$1_$2", + underscore_delimited_full_name, op, field->name()); +} + +bool IsSupportedFieldType(const FieldDescriptor* field) { + return !field->is_repeated() && + (field->cpp_type() == FieldDescriptor::CPPTYPE_BOOL || + field->cpp_type() == FieldDescriptor::CPPTYPE_INT64); +} + +std::string PrimitiveRsTypeName(const FieldDescriptor* field) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT64: + return "i64"; + case FieldDescriptor::CPPTYPE_BOOL: + return "bool"; + default: + break; + } + ABSL_LOG(FATAL) << "Unsupported field type: " << field->type_name(); + return ""; +} + +void GenerateAccessorFns(const Descriptor* msg, google::protobuf::io::Printer& p, + absl::string_view underscore_delimited_full_name) { + for (int i = 0; i < msg->field_count(); ++i) { + const FieldDescriptor* field = msg->field(i); + if (!IsSupportedFieldType(field)) { + continue; + } + p.Emit( + { + {"field_name", field->name()}, + {"FieldType", PrimitiveRsTypeName(field)}, + {"hazzer_thunk_name", + GetAccessorThunkName(field, "has", + underscore_delimited_full_name)}, + {"getter_thunk_name", + GetAccessorThunkName(field, "", underscore_delimited_full_name)}, + {"setter_thunk_name", + GetAccessorThunkName(field, "set", + underscore_delimited_full_name)}, + {"clearer_thunk_name", + GetAccessorThunkName(field, "clear", + underscore_delimited_full_name)}, + }, + R"rs( + pub fn $field_name$(&self) -> Option<$FieldType$> { + if !unsafe { $hazzer_thunk_name$(self.msg) } { + return None; + } + Some(unsafe { $getter_thunk_name$(self.msg) }) + } + pub fn $field_name$_set(&mut self, val: Option<$FieldType$>) { + match val { + Some(val) => unsafe { $setter_thunk_name$(self.msg, val) }, + None => unsafe { $clearer_thunk_name$(self.msg) }, + } + } + )rs"); + } +} + +void GenerateAccessorThunkRsDeclarations( + const Descriptor* msg, google::protobuf::io::Printer& p, + std::string underscore_delimited_full_name) { + for (int i = 0; i < msg->field_count(); ++i) { + const FieldDescriptor* field = msg->field(i); + if (!IsSupportedFieldType(field)) { + continue; + } + p.Emit( + { + {"FieldType", PrimitiveRsTypeName(field)}, + {"hazzer_thunk_name", + GetAccessorThunkName(field, "has", + underscore_delimited_full_name)}, + {"getter_thunk_name", + GetAccessorThunkName(field, "", underscore_delimited_full_name)}, + {"setter_thunk_name", + GetAccessorThunkName(field, "set", + underscore_delimited_full_name)}, + {"clearer_thunk_name", + GetAccessorThunkName(field, "clear", + underscore_delimited_full_name)}, + }, + R"rs( + fn $hazzer_thunk_name$(raw_msg: ::__std::ptr::NonNull) -> bool; + fn $getter_thunk_name$(raw_msg: ::__std::ptr::NonNull) -> $FieldType$; + fn $setter_thunk_name$(raw_msg: ::__std::ptr::NonNull, val: $FieldType$); + fn $clearer_thunk_name$(raw_msg: ::__std::ptr::NonNull); + )rs"); + } +} + +void GenerateAccessorThunksCcDefinitions( + const Descriptor* msg, google::protobuf::io::Printer& p, + absl::string_view underscore_delimited_full_name) { + for (int i = 0; i < msg->field_count(); ++i) { + const FieldDescriptor* field = msg->field(i); + if (!IsSupportedFieldType(field)) { + continue; + } + p.Emit( + {{"field_name", field->name()}, + {"FieldType", cpp::PrimitiveTypeName(field->cpp_type())}, + {"namespace", cpp::Namespace(msg)}, + {"hazzer_thunk_name", + GetAccessorThunkName(field, "has", underscore_delimited_full_name)}, + {"getter_thunk_name", + GetAccessorThunkName(field, "", underscore_delimited_full_name)}, + {"setter_thunk_name", + GetAccessorThunkName(field, "set", underscore_delimited_full_name)}, + {"clearer_thunk_name", + GetAccessorThunkName(field, "clear", + underscore_delimited_full_name)}}, + R"cc( + extern "C" { + bool $hazzer_thunk_name$($namespace$::$Msg$* msg) { + return msg->has_$field_name$(); + } + $FieldType$ $getter_thunk_name$($namespace$::$Msg$* msg) { + return msg->$field_name$(); + } + void $setter_thunk_name$($namespace$::$Msg$* msg, $FieldType$ val) { + msg->set_$field_name$(val); + } + void $clearer_thunk_name$($namespace$::$Msg$* msg) { + msg->clear_$field_name$(); + } + } + )cc"); + } +} + void GenerateForCpp(const FileDescriptor* file, google::protobuf::io::Printer& p) { for (int i = 0; i < file->message_type_count(); ++i) { const Descriptor* msg = file->message_type(i); + std::string underscore_delimited_full_name = + GetUnderscoreDelimitedFullName(msg); p.Emit( { {"Msg", msg->name()}, - {"pkg_Msg", GetUnderscoreDelimitedFullName(msg)}, + {"pkg_Msg", underscore_delimited_full_name}, + {"accessor_fns", + [&] { + GenerateAccessorFns(file->message_type(i), p, + underscore_delimited_full_name); + }}, + {"accessor_thunks", + [&] { + GenerateAccessorThunkRsDeclarations( + file->message_type(i), p, underscore_delimited_full_name); + }}, }, R"rs( + #[allow(non_camel_case_types)] pub struct $Msg$ { msg: ::__std::ptr::NonNull, } @@ -141,12 +295,15 @@ void GenerateForCpp(const FileDescriptor* file, google::protobuf::io::Printer& p }; success.then_some(()).ok_or(::__pb::ParseError) } + $accessor_fns$ } extern "C" { fn __rust_proto_thunk__$pkg_Msg$__new() -> ::__std::ptr::NonNull; fn __rust_proto_thunk__$pkg_Msg$__serialize(raw_msg: ::__std::ptr::NonNull) -> ::__pb::SerializedData; fn __rust_proto_thunk__$pkg_Msg$__deserialize(raw_msg: ::__std::ptr::NonNull, data: ::__pb::SerializedData) -> bool; + + $accessor_thunks$ } )rs"); } @@ -155,11 +312,18 @@ void GenerateForCpp(const FileDescriptor* file, google::protobuf::io::Printer& p void GenerateThunksForCpp(const FileDescriptor* file, google::protobuf::io::Printer& p) { for (int i = 0; i < file->message_type_count(); ++i) { const Descriptor* msg = file->message_type(i); + std::string underscore_delimited_full_name = + GetUnderscoreDelimitedFullName(msg); p.Emit( { {"Msg", msg->name()}, {"pkg_Msg", GetUnderscoreDelimitedFullName(msg)}, {"namespace", cpp::Namespace(msg)}, + {"accessor_thunks", + [&] { + GenerateAccessorThunksCcDefinitions( + file->message_type(i), p, underscore_delimited_full_name); + }}, }, R"cc( extern "C" { @@ -175,6 +339,8 @@ void GenerateThunksForCpp(const FileDescriptor* file, google::protobuf::io::Prin google::protobuf::rust_internal::SerializedData data) { return msg->ParseFromArray(data.data, data.len); } + + $accessor_thunks$ } )cc"); } diff --git a/src/google/protobuf/compiler/rust/upb_kernel.cc b/src/google/protobuf/compiler/rust/upb_kernel.cc index 18fb0efdbd..69104b9f68 100644 --- a/src/google/protobuf/compiler/rust/upb_kernel.cc +++ b/src/google/protobuf/compiler/rust/upb_kernel.cc @@ -30,6 +30,7 @@ #include "google/protobuf/compiler/rust/upb_kernel.h" +#include #include #include #include