From e3432c283d5473a89474628523b8f21b914719f4 Mon Sep 17 00:00:00 2001 From: Hong Shin Date: Wed, 15 Nov 2023 14:33:41 -0800 Subject: [PATCH] Add submsg support for strings and bytes. msg.submsg().x() and msg.submsg().x_mut() should now be callable for strings and bytes. Main idea here was to return &[u8] for bytes (vs [u8]) and &ProtoStr instead of &str or &[u8] for strings. This CL also expunges IsSimpleScalar. PiperOrigin-RevId: 582809307 --- rust/test/nested.proto | 2 + rust/test/shared/simple_nested_test.rs | 4 + src/google/protobuf/compiler/rust/message.cc | 87 ++++++++++++-------- 3 files changed, 60 insertions(+), 33 deletions(-) diff --git a/rust/test/nested.proto b/rust/test/nested.proto index 6663a5ab1e..e6d4f29a7a 100644 --- a/rust/test/nested.proto +++ b/rust/test/nested.proto @@ -24,6 +24,8 @@ message Outer { optional sfixed32 sfixed32 = 11; optional sfixed64 sfixed64 = 12; optional bool bool = 13; + optional string string = 14; + optional bytes bytes = 15; message SuperInner { message DuperInner { diff --git a/rust/test/shared/simple_nested_test.rs b/rust/test/shared/simple_nested_test.rs index b5d6ba3f8a..d40a7381de 100644 --- a/rust/test/shared/simple_nested_test.rs +++ b/rust/test/shared/simple_nested_test.rs @@ -27,6 +27,8 @@ fn test_nested_views() { assert_that!(inner_msg.sfixed32(), eq(0)); assert_that!(inner_msg.sfixed64(), eq(0)); assert_that!(inner_msg.bool(), eq(false)); + assert_that!(*inner_msg.string().as_bytes(), empty()); + assert_that!(*inner_msg.bytes(), empty()); } #[test] @@ -48,6 +50,8 @@ fn test_nested_muts() { assert_that!(inner_msg.sfixed32(), eq(0)); assert_that!(inner_msg.sfixed64(), eq(0)); assert_that!(inner_msg.bool(), eq(false)); + assert_that!(*inner_msg.string().as_bytes(), empty()); + assert_that!(*inner_msg.bytes(), empty()); } #[test] diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index c70af33c30..a3968b6e52 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -9,6 +9,7 @@ #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" #include "google/protobuf/compiler/cpp/names.h" @@ -163,46 +164,66 @@ void MessageDrop(Context msg) { )rs"); } -// TODO: deferring on strings and bytes for now, eventually this -// check will go away as we support more than just simple scalars -bool IsSimpleScalar(FieldDescriptor::Type type) { - return type == FieldDescriptor::TYPE_DOUBLE || - type == FieldDescriptor::TYPE_FLOAT || - type == FieldDescriptor::TYPE_INT32 || - type == FieldDescriptor::TYPE_INT64 || - type == FieldDescriptor::TYPE_UINT32 || - type == FieldDescriptor::TYPE_UINT64 || - type == FieldDescriptor::TYPE_SINT32 || - type == FieldDescriptor::TYPE_SINT64 || - type == FieldDescriptor::TYPE_FIXED32 || - type == FieldDescriptor::TYPE_FIXED64 || - type == FieldDescriptor::TYPE_SFIXED32 || - type == FieldDescriptor::TYPE_SFIXED64 || - type == FieldDescriptor::TYPE_BOOL; -} - void GetterForViewOrMut(Context field, bool is_mut) { - // If we're dealing with a Mut, the getter must be supplied self.inner.msg() - // whereas a View has to be supplied self.msg - field.Emit( - { - {"field", field.desc().name()}, - {"getter_thunk", Thunk(field, "get")}, - {"self", is_mut ? "self.inner.msg()" : "self.msg"}, - {"Scalar", PrimitiveRsTypeName(field.desc())}, - }, - R"rs( - pub fn r#$field$(&self) -> $Scalar$ { - unsafe { $getter_thunk$($self$) } - } - )rs"); + auto fieldName = field.desc().name(); + auto fieldType = field.desc().type(); + auto getter_thunk = Thunk(field, "get"); + // If we're dealing with a Mut, the getter must be supplied + // self.inner.msg() whereas a View has to be supplied self.msg + auto self = is_mut ? "self.inner.msg()" : "self.msg"; + auto returnType = PrimitiveRsTypeName(field.desc()); + + if (fieldType == FieldDescriptor::TYPE_STRING) { + field.Emit( + { + {"field", fieldName}, + {"self", self}, + {"getter_thunk", getter_thunk}, + {"ReturnType", returnType}, + }, + R"rs( + pub fn r#$field$(&self) -> &$ReturnType$ { + let s = unsafe { $getter_thunk$($self$).as_ref() }; + unsafe { __pb::ProtoStr::from_utf8_unchecked(s) } + } + )rs"); + } else if (fieldType == FieldDescriptor::TYPE_BYTES) { + field.Emit( + { + {"field", fieldName}, + {"self", self}, + {"getter_thunk", getter_thunk}, + {"ReturnType", returnType}, + }, + R"rs( + pub fn r#$field$(&self) -> &$ReturnType$ { + unsafe { $getter_thunk$($self$).as_ref() } + } + )rs"); + } else { + field.Emit({{"field", fieldName}, + {"getter_thunk", getter_thunk}, + {"self", self}, + {"ReturnType", returnType}}, + R"rs( + pub fn r#$field$(&self) -> $ReturnType$ { + unsafe { $getter_thunk$($self$) } + } + )rs"); + } } void AccessorsForViewOrMut(Context msg, bool is_mut) { for (int i = 0; i < msg.desc().field_count(); ++i) { auto field = msg.WithDesc(*msg.desc().field(i)); if (field.desc().is_repeated()) continue; - if (!IsSimpleScalar(field.desc().type())) continue; + // TODO - add cord support + if (field.desc().options().has_ctype()) continue; + // TODO + if (field.desc().type() == FieldDescriptor::TYPE_MESSAGE || + field.desc().type() == FieldDescriptor::TYPE_ENUM || + field.desc().type() == FieldDescriptor::TYPE_GROUP) + continue; GetterForViewOrMut(field, is_mut); msg.printer().PrintRaw("\n"); }