diff --git a/rust/test/upb/accessors_test.rs b/rust/test/upb/accessors_test.rs index ab909efae8..5e8cc717a2 100644 --- a/rust/test/upb/accessors_test.rs +++ b/rust/test/upb/accessors_test.rs @@ -40,3 +40,16 @@ fn test_optional_bool() { test_all_types.optional_bool_set(None); assert_eq!(test_all_types.optional_bool(), None); } + +#[test] +fn test_optional_int64() { + let mut test_all_types: unittest_proto::TestAllTypes = unittest_proto::TestAllTypes::new(); + test_all_types.optional_int64_set(Some(10)); + assert_eq!(test_all_types.optional_int64(), Some(10)); + + test_all_types.optional_int64_set(Some(-10)); + assert_eq!(test_all_types.optional_int64(), Some(-10)); + + test_all_types.optional_int64_set(None); + assert_eq!(test_all_types.optional_int64(), None); +} diff --git a/src/google/protobuf/compiler/rust/upb_kernel.cc b/src/google/protobuf/compiler/rust/upb_kernel.cc index 8b0896123c..18fb0efdbd 100644 --- a/src/google/protobuf/compiler/rust/upb_kernel.cc +++ b/src/google/protobuf/compiler/rust/upb_kernel.cc @@ -65,22 +65,35 @@ bool IsSupported(const FieldDescriptor* field) { return field->is_optional() && !field->is_repeated(); } -void GenBoolAccessors(const std::string& upb_msg_prefix, - const std::string& msg_name, const FieldDescriptor* field, - google::protobuf::io::Printer& p) { +absl::string_view RustTypeFromCppType(const FieldDescriptor::Type field_type) { + switch (field_type) { + case FieldDescriptor::Type::TYPE_BOOL: + return "bool"; + case FieldDescriptor::Type::TYPE_INT64: + return "i64"; + default: { + ABSL_LOG(FATAL) << "Unsupported field type: " << field_type; + } + } +} + +void GenScalarAccessors(const std::string& upb_msg_prefix, + const std::string& msg_name, + const FieldDescriptor* field, google::protobuf::io::Printer& p) { if (!IsSupported(field)) { return; } p.Emit({{"Msg", msg_name}, {"field_name", field->name()}, + {"data_type", RustTypeFromCppType(field->type())}, {"has_thunk", UpbThunkName(field, upb_msg_prefix, "_has_")}, {"getter_thunk", UpbThunkName(field, upb_msg_prefix, "_")}, {"setter_thunk", UpbThunkName(field, upb_msg_prefix, "_set_")}, {"clear_thunk", UpbThunkName(field, upb_msg_prefix, "_clear_")}}, R"rs( impl $Msg$ { - pub fn $field_name$(&self) -> Option { + pub fn $field_name$(&self) -> Option<$data_type$> { let field_present = unsafe { $has_thunk$(self.msg) }; if !field_present { return None; @@ -89,7 +102,7 @@ void GenBoolAccessors(const std::string& upb_msg_prefix, Some(value) } - pub fn $field_name$_set(&mut self, value: Option) { + pub fn $field_name$_set(&mut self, value: Option<$data_type$>) { match value { Some(value) => unsafe { $setter_thunk$(self.msg, value); }, None => unsafe { $clear_thunk$(self.msg); } @@ -98,11 +111,11 @@ void GenBoolAccessors(const std::string& upb_msg_prefix, } extern "C" { - fn $getter_thunk$(msg: ::__std::ptr::NonNull) -> bool; + fn $getter_thunk$(msg: ::__std::ptr::NonNull) -> $data_type$; fn $has_thunk$(msg: ::__std::ptr::NonNull) -> bool; fn $setter_thunk$( msg: ::__std::ptr::NonNull, - value: bool + value: $data_type$ ); fn $clear_thunk$(msg: ::__std::ptr::NonNull); } @@ -117,8 +130,9 @@ void GenFieldAccessors(const Descriptor* msg_descriptor, auto field = msg_descriptor->field(i); switch (field->type()) { + case FieldDescriptor::Type::TYPE_INT64: case FieldDescriptor::Type::TYPE_BOOL: - GenBoolAccessors(upb_msg_prefix, msg_descriptor->name(), field, p); + GenScalarAccessors(upb_msg_prefix, msg_descriptor->name(), field, p); break; default: // Not implemented type.