diff --git a/rust/optional.rs b/rust/optional.rs index bdbc5ec4d2..2e576e114c 100644 --- a/rust/optional.rs +++ b/rust/optional.rs @@ -234,7 +234,7 @@ pub struct PresentField<'msg, T> where T: ProxiedWithPresence + ?Sized + 'msg, { - inner: T::PresentMutData<'msg>, + pub(crate) inner: T::PresentMutData<'msg>, } impl<'msg, T: ProxiedWithPresence + ?Sized + 'msg> Debug for PresentField<'msg, T> { @@ -312,7 +312,7 @@ pub struct AbsentField<'a, T> where T: ProxiedWithPresence + ?Sized + 'a, { - inner: T::AbsentMutData<'a>, + pub(crate) inner: T::AbsentMutData<'a>, } impl<'msg, T: ProxiedWithPresence + ?Sized + 'msg> Debug for AbsentField<'msg, T> { diff --git a/rust/proxied.rs b/rust/proxied.rs index e5c4bb0f5d..1bbf35e020 100644 --- a/rust/proxied.rs +++ b/rust/proxied.rs @@ -395,8 +395,8 @@ mod tests { impl SettableValue for Cow<'_, str> { fn set_on(self, _private: Private, mutator: Mut) { match self { - Cow::Owned(x) => x.set_on(Private, mutator), - Cow::Borrowed(x) => x.set_on(Private, mutator), + Cow::Owned(x) => >::set_on(x, Private, mutator), + Cow::Borrowed(x) => <&str as SettableValue>::set_on(x, Private, mutator), } } } diff --git a/rust/shared.rs b/rust/shared.rs index b6f8805729..6313eae80f 100644 --- a/rust/shared.rs +++ b/rust/shared.rs @@ -44,7 +44,7 @@ pub mod __public { pub use crate::proxied::{ Mut, MutProxy, Proxied, ProxiedWithPresence, SettableValue, View, ViewProxy, }; - pub use crate::string::{BytesMut, ProtoStr}; + pub use crate::string::{BytesMut, ProtoStr, ProtoStrMut}; } pub use __public::*; diff --git a/rust/string.rs b/rust/string.rs index 523a50846d..4ad4fbcd5f 100644 --- a/rust/string.rs +++ b/rust/string.rs @@ -35,7 +35,10 @@ use crate::__internal::{Private, PtrAndLen, RawMessage}; use crate::__runtime::{BytesAbsentMutData, BytesPresentMutData, InnerBytesMut}; use crate::macros::impl_forwarding_settable_value; -use crate::{Mut, MutProxy, Proxied, ProxiedWithPresence, SettableValue, View, ViewProxy}; +use crate::{ + AbsentField, FieldEntry, Mut, MutProxy, Optional, PresentField, Proxied, ProxiedWithPresence, + SettableValue, View, ViewProxy, +}; use std::borrow::Cow; use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd}; use std::convert::{AsMut, AsRef}; @@ -111,7 +114,7 @@ impl<'msg> BytesMut<'msg> { /// `BytesMut::clear` results in the accessor returning an empty string /// while `FieldEntry::clear` results in the non-empty default. /// - /// However, for a proto3 `bytes` that have implicit presence, there is no + /// However, for a proto3 `bytes` that has implicit presence, there is no /// distinction between these states: unset `bytes` is the same as empty /// `bytes` and the default is always the empty string. /// @@ -204,7 +207,7 @@ impl<'msg> MutProxy<'msg> for BytesMut<'msg> { } } -impl<'bytes> SettableValue<[u8]> for &'bytes [u8] { +impl SettableValue<[u8]> for &'_ [u8] { fn set_on(self, _private: Private, mutator: BytesMut<'_>) { // SAFETY: this is a `bytes` field with no restriction on UTF-8. unsafe { mutator.inner.set(self) } @@ -231,7 +234,7 @@ impl<'bytes> SettableValue<[u8]> for &'bytes [u8] { } } -impl<'a, const N: usize> SettableValue<[u8]> for &'a [u8; N] { +impl SettableValue<[u8]> for &'_ [u8; N] { // forward to `self[..]` impl_forwarding_settable_value!([u8], self => &self[..]); } @@ -427,6 +430,12 @@ impl<'msg> From<&'msg ProtoStr> for &'msg [u8] { } } +impl<'msg> From<&'msg str> for &'msg ProtoStr { + fn from(val: &'msg str) -> &'msg ProtoStr { + ProtoStr::from_str(val) + } +} + impl<'msg> TryFrom<&'msg ProtoStr> for &'msg str { type Error = Utf8Error; @@ -435,6 +444,14 @@ impl<'msg> TryFrom<&'msg ProtoStr> for &'msg str { } } +impl<'msg> TryFrom<&'msg [u8]> for &'msg ProtoStr { + type Error = Utf8Error; + + fn try_from(val: &'msg [u8]) -> Result<&'msg ProtoStr, Utf8Error> { + Ok(ProtoStr::from_str(std::str::from_utf8(val)?)) + } +} + impl fmt::Debug for ProtoStr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(&Utf8Chunks::new(self.as_bytes()).debug(), f) @@ -455,7 +472,306 @@ impl fmt::Display for ProtoStr { } } -// TODO(b/285309330): Add `ProtoStrMut` +impl Hash for ProtoStr { + fn hash(&self, state: &mut H) { + self.as_bytes().hash(state) + } +} + +impl Eq for ProtoStr {} +impl Ord for ProtoStr { + fn cmp(&self, other: &ProtoStr) -> Ordering { + self.as_bytes().cmp(other.as_bytes()) + } +} + +impl Proxied for ProtoStr { + type View<'msg> = &'msg ProtoStr; + type Mut<'msg> = ProtoStrMut<'msg>; +} + +impl ProxiedWithPresence for ProtoStr { + type PresentMutData<'msg> = StrPresentMutData<'msg>; + type AbsentMutData<'msg> = StrAbsentMutData<'msg>; + + fn clear_present_field(present_mutator: Self::PresentMutData<'_>) -> Self::AbsentMutData<'_> { + StrAbsentMutData(present_mutator.0.clear()) + } + + fn set_absent_to_default(absent_mutator: Self::AbsentMutData<'_>) -> Self::PresentMutData<'_> { + StrPresentMutData(absent_mutator.0.set_absent_to_default()) + } +} + +impl<'msg> ViewProxy<'msg> for &'msg ProtoStr { + type Proxied = ProtoStr; + + fn as_view(&self) -> &ProtoStr { + self + } + + fn into_view<'shorter>(self) -> &'shorter ProtoStr + where + 'msg: 'shorter, + { + self + } +} + +/// Non-exported newtype for `ProxiedWithPresence::PresentData` +#[derive(Debug)] +pub struct StrPresentMutData<'msg>(BytesPresentMutData<'msg>); + +impl<'msg> ViewProxy<'msg> for StrPresentMutData<'msg> { + type Proxied = ProtoStr; + + fn as_view(&self) -> View<'_, ProtoStr> { + // SAFETY: The `ProtoStr` API guards against non-UTF-8 data. The runtime does + // not require `ProtoStr` to be UTF-8 if it could be mutated outside of these + // guards, such as through FFI. + unsafe { ProtoStr::from_utf8_unchecked(self.0.as_view()) } + } + + fn into_view<'shorter>(self) -> View<'shorter, ProtoStr> + where + 'msg: 'shorter, + { + // SAFETY: The `ProtoStr` API guards against non-UTF-8 data. The runtime does + // not require `ProtoStr` to be UTF-8 if it could be mutated outside of these + // guards, such as through FFI. + unsafe { ProtoStr::from_utf8_unchecked(self.0.into_view()) } + } +} + +impl<'msg> MutProxy<'msg> for StrPresentMutData<'msg> { + fn as_mut(&mut self) -> Mut<'_, ProtoStr> { + ProtoStrMut { bytes: self.0.as_mut() } + } + + fn into_mut<'shorter>(self) -> Mut<'shorter, ProtoStr> + where + 'msg: 'shorter, + { + ProtoStrMut { bytes: self.0.into_mut() } + } +} + +/// Non-exported newtype for `ProxiedWithPresence::AbsentData` +#[derive(Debug)] +pub struct StrAbsentMutData<'msg>(BytesAbsentMutData<'msg>); + +impl<'msg> ViewProxy<'msg> for StrAbsentMutData<'msg> { + type Proxied = ProtoStr; + + fn as_view(&self) -> View<'_, ProtoStr> { + // SAFETY: The `ProtoStr` API guards against non-UTF-8 data. The runtime does + // not require `ProtoStr` to be UTF-8 if it could be mutated outside of these + // guards, such as through FFI. + unsafe { ProtoStr::from_utf8_unchecked(self.0.as_view()) } + } + + fn into_view<'shorter>(self) -> View<'shorter, ProtoStr> + where + 'msg: 'shorter, + { + // SAFETY: The `ProtoStr` API guards against non-UTF-8 data. The runtime does + // not require `ProtoStr` to be UTF-8 if it could be mutated outside of these + // guards, such as through FFI. + unsafe { ProtoStr::from_utf8_unchecked(self.0.into_view()) } + } +} + +#[derive(Debug)] +pub struct ProtoStrMut<'msg> { + bytes: BytesMut<'msg>, +} + +impl<'msg> ProtoStrMut<'msg> { + /// Constructs a new `ProtoStrMut` from its internal, runtime-dependent + /// part. + #[doc(hidden)] + pub fn from_inner(_private: Private, inner: InnerBytesMut<'msg>) -> Self { + Self { bytes: BytesMut { inner } } + } + + /// Converts a `bytes` `FieldEntry` into a `string` one. Used by gencode. + #[doc(hidden)] + pub fn field_entry_from_bytes( + _private: Private, + field_entry: FieldEntry<'_, [u8]>, + ) -> FieldEntry { + match field_entry { + Optional::Set(present) => { + Optional::Set(PresentField::from_inner(Private, StrPresentMutData(present.inner))) + } + Optional::Unset(absent) => { + Optional::Unset(AbsentField::from_inner(Private, StrAbsentMutData(absent.inner))) + } + } + } + + /// Gets the current value of the field. + pub fn get(&self) -> &ProtoStr { + self.as_view() + } + + /// Sets the string to the given `val`, cloning any borrowed data. + /// + /// This method accepts both owned and borrowed strings; if the runtime + /// supports it, an owned value will not reallocate when setting the + /// string. + pub fn set(&mut self, val: impl SettableValue) { + val.set_on(Private, MutProxy::as_mut(self)) + } + + /// Truncates the string. + /// + /// Has no effect if `new_len` is larger than the current `len`. + /// + /// If `new_len` does not lie on a UTF-8 `char` boundary, behavior is + /// runtime-dependent. If this occurs, the runtime may: + /// + /// - Panic + /// - Truncate the string further to be on a `char` boundary. + /// - Truncate to `new_len`, resulting in a `ProtoStr` with a non-UTF8 tail. + pub fn truncate(&mut self, new_len: usize) { + self.bytes.truncate(new_len) + } + + /// Clears the string, setting it to the empty string. + /// + /// # Compared with `FieldEntry::clear` + /// + /// Note that this is different than marking an `optional string` field as + /// absent; if this cleared `string` is in an `optional`, + /// `FieldEntry::is_set` will still return `true` after this method is + /// invoked. + /// + /// This also means that if the field has a non-empty default, + /// `ProtoStrMut::clear` results in the accessor returning an empty string + /// while `FieldEntry::clear` results in the non-empty default. + /// + /// However, for a proto3 `string` that has implicit presence, there is no + /// distinction between these states: unset `string` is the same as empty + /// `string` and the default is always the empty string. + /// + /// In the C++ API, this is the difference between + /// `msg.clear_string_field()` + /// and `msg.mutable_string_field()->clear()`. + /// + /// Having the same name and signature as `FieldEntry::clear` makes code + /// that calls `field_mut().clear()` easier to migrate from implicit + /// to explicit presence. + pub fn clear(&mut self) { + self.truncate(0); + } +} + +impl Deref for ProtoStrMut<'_> { + type Target = ProtoStr; + fn deref(&self) -> &ProtoStr { + self.as_view() + } +} + +impl AsRef for ProtoStrMut<'_> { + fn as_ref(&self) -> &ProtoStr { + self.as_view() + } +} + +impl AsRef<[u8]> for ProtoStrMut<'_> { + fn as_ref(&self) -> &[u8] { + self.as_view().as_bytes() + } +} + +impl<'msg> ViewProxy<'msg> for ProtoStrMut<'msg> { + type Proxied = ProtoStr; + + fn as_view(&self) -> &ProtoStr { + // SAFETY: The `ProtoStr` API guards against non-UTF-8 data. The runtime does + // not require `ProtoStr` to be UTF-8 if it could be mutated outside of these + // guards, such as through FFI. + unsafe { ProtoStr::from_utf8_unchecked(self.bytes.as_view()) } + } + + fn into_view<'shorter>(self) -> &'shorter ProtoStr + where + 'msg: 'shorter, + { + unsafe { ProtoStr::from_utf8_unchecked(self.bytes.into_view()) } + } +} + +impl<'msg> MutProxy<'msg> for ProtoStrMut<'msg> { + fn as_mut(&mut self) -> ProtoStrMut<'_> { + ProtoStrMut { bytes: BytesMut { inner: self.bytes.inner } } + } + + fn into_mut<'shorter>(self) -> ProtoStrMut<'shorter> + where + 'msg: 'shorter, + { + ProtoStrMut { bytes: BytesMut { inner: self.bytes.inner } } + } +} + +impl SettableValue for &'_ ProtoStr { + fn set_on(self, _private: Private, mutator: ProtoStrMut<'_>) { + // SAFETY: A `ProtoStr` has the same UTF-8 validity requirement as the runtime. + unsafe { mutator.bytes.inner.set(self.as_bytes()) } + } + + fn set_on_absent( + self, + _private: Private, + absent_mutator: ::AbsentMutData<'_>, + ) -> ::PresentMutData<'_> { + // SAFETY: A `ProtoStr` has the same UTF-8 validity requirement as the runtime. + StrPresentMutData(unsafe { absent_mutator.0.set(self.as_bytes()) }) + } + + fn set_on_present( + self, + _private: Private, + present_mutator: ::PresentMutData<'_>, + ) { + // SAFETY: A `ProtoStr` has the same UTF-8 validity requirement as the runtime. + unsafe { + present_mutator.0.set(self.as_bytes()); + } + } +} + +impl SettableValue for &'_ str { + impl_forwarding_settable_value!(ProtoStr, self => ProtoStr::from_str(self)); +} + +impl SettableValue for String { + // TODO(b/293956360): Investigate taking ownership of this when allowed by the + // runtime. + impl_forwarding_settable_value!(ProtoStr, self => ProtoStr::from_str(&self)); +} + +impl SettableValue for Cow<'_, str> { + // TODO(b/293956360): Investigate taking ownership of this when allowed by the + // runtime. + impl_forwarding_settable_value!(ProtoStr, self => ProtoStr::from_str(&self)); +} + +impl Hash for ProtoStrMut<'_> { + fn hash(&self, state: &mut H) { + self.deref().hash(state) + } +} + +impl Eq for ProtoStrMut<'_> {} +impl<'msg> Ord for ProtoStrMut<'msg> { + fn cmp(&self, other: &ProtoStrMut<'msg>) -> Ordering { + self.deref().cmp(other.deref()) + } +} /// Implements `PartialCmp` and `PartialEq` for the `lhs` against the `rhs` /// using `AsRef<[u8]>`. @@ -493,12 +809,19 @@ impl_bytes_partial_cmp!( // `ProtoStr` against protobuf types <()> ProtoStr => ProtoStr, + <('a)> ProtoStr => ProtoStrMut<'a>, // `ProtoStr` against foreign types <()> ProtoStr => str, <()> str => ProtoStr, - // TODO(b/285309330): `ProtoStrMut` impls + // `ProtoStrMut` against protobuf types + <('a, 'b)> ProtoStrMut<'a> => ProtoStrMut<'b>, + <('a)> ProtoStrMut<'a> => ProtoStr, + + // `ProtoStrMut` against foreign types + <('a)> ProtoStrMut<'a> => str, + <('a)> str => ProtoStrMut<'a>, ); #[cfg(test)] diff --git a/rust/test/shared/accessors_proto3_test.rs b/rust/test/shared/accessors_proto3_test.rs index 4e8f55eb7d..6228523b80 100644 --- a/rust/test/shared/accessors_proto3_test.rs +++ b/rust/test/shared/accessors_proto3_test.rs @@ -49,8 +49,8 @@ fn test_fixed32_accessors() { #[test] fn test_bytes_accessors() { let mut msg = TestAllTypes::new(); - // Note: even though its named 'optional_bytes' the field is actually not proto3 - // optional, so it does not support presence. + // Note: even though it's named 'optional_bytes', the field is actually not + // proto3 optional, so it does not support presence. assert_eq!(msg.optional_bytes(), b""); assert_eq!(msg.optional_bytes_mut().get(), b""); @@ -120,6 +120,73 @@ fn test_optional_bytes_accessors() { assert_eq!(msg.optional_bytes_mut().or_default().get(), b"\xffbinary\x85non-utf8"); } +#[test] +fn test_string_accessors() { + let mut msg = TestAllTypes::new(); + // Note: even though it's named 'optional_string', the field is actually not + // proto3 optional, so it does not support presence. + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_mut().get(), ""); + + msg.optional_string_mut().set("accessors_test"); + assert_eq!(msg.optional_string(), "accessors_test"); + assert_eq!(msg.optional_string_mut().get(), "accessors_test"); + + { + let s = String::from("hello world"); + msg.optional_string_mut().set(&s[..]); + } + assert_eq!(msg.optional_string(), "hello world"); + assert_eq!(msg.optional_string_mut().get(), "hello world"); + + msg.optional_string_mut().clear(); + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_mut().get(), ""); + + msg.optional_string_mut().set(""); + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_mut().get(), ""); +} + +#[test] +fn test_optional_string_accessors() { + let mut msg = TestProto3Optional::new(); + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_opt(), Optional::Unset("".into())); + assert_eq!(msg.optional_string_mut().get(), ""); + assert!(msg.optional_string_mut().is_unset()); + + { + let s = String::from("hello world"); + msg.optional_string_mut().set(&s[..]); + } + assert_eq!(msg.optional_string(), "hello world"); + assert_eq!(msg.optional_string_opt(), Optional::Set("hello world".into())); + assert!(msg.optional_string_mut().is_set()); + assert_eq!(msg.optional_string_mut().get(), "hello world"); + + msg.optional_string_mut().or_default().set("accessors_test"); + assert_eq!(msg.optional_string(), "accessors_test"); + assert_eq!(msg.optional_string_opt(), Optional::Set("accessors_test".into())); + assert!(msg.optional_string_mut().is_set()); + assert_eq!(msg.optional_string_mut().get(), "accessors_test"); + assert_eq!(msg.optional_string_mut().or_default().get(), "accessors_test"); + + msg.optional_string_mut().clear(); + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_opt(), Optional::Unset("".into())); + assert!(msg.optional_string_mut().is_unset()); + + msg.optional_string_mut().set(""); + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_opt(), Optional::Set("".into())); + + msg.optional_string_mut().clear(); + msg.optional_string_mut().or_default(); + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_opt(), Optional::Set("".into())); +} + #[test] fn test_oneof_accessors() { let mut msg = TestAllTypes::new(); diff --git a/rust/test/shared/accessors_test.rs b/rust/test/shared/accessors_test.rs index 4bedb53bc9..c323eee631 100644 --- a/rust/test/shared/accessors_test.rs +++ b/rust/test/shared/accessors_test.rs @@ -307,6 +307,84 @@ fn test_nonempty_default_bytes_accessors() { assert_eq!(msg.default_bytes_mut().or_default().get(), b"\xffbinary\x85non-utf8"); } +#[test] +fn test_optional_string_accessors() { + let mut msg = TestAllTypes::new(); + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_opt(), Optional::Unset("".into())); + assert_eq!(msg.optional_string_mut().get(), ""); + assert!(msg.optional_string_mut().is_unset()); + + { + let s = String::from("hello world"); + msg.optional_string_mut().set(&s[..]); + } + assert_eq!(msg.optional_string(), "hello world"); + assert_eq!(msg.optional_string_opt(), Optional::Set("hello world".into())); + assert!(msg.optional_string_mut().is_set()); + assert_eq!(msg.optional_string_mut().get(), "hello world"); + + msg.optional_string_mut().or_default().set("accessors_test"); + assert_eq!(msg.optional_string(), "accessors_test"); + assert_eq!(msg.optional_string_opt(), Optional::Set("accessors_test".into())); + assert!(msg.optional_string_mut().is_set()); + assert_eq!(msg.optional_string_mut().get(), "accessors_test"); + assert_eq!(msg.optional_string_mut().or_default().get(), "accessors_test"); + + msg.optional_string_mut().clear(); + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_opt(), Optional::Unset("".into())); + assert!(msg.optional_string_mut().is_unset()); + + msg.optional_string_mut().set(""); + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_opt(), Optional::Set("".into())); + + msg.optional_string_mut().clear(); + msg.optional_string_mut().or_default(); + assert_eq!(msg.optional_string(), ""); + assert_eq!(msg.optional_string_opt(), Optional::Set("".into())); +} + +#[test] +fn test_nonempty_default_string_accessors() { + let mut msg = TestAllTypes::new(); + assert_eq!(msg.default_string(), "hello"); + assert_eq!(msg.default_string_opt(), Optional::Unset("hello".into())); + assert_eq!(msg.default_string_mut().get(), "hello"); + assert!(msg.default_string_mut().is_unset()); + + { + let s = String::from("hello world"); + msg.default_string_mut().set(&s[..]); + } + assert_eq!(msg.default_string(), "hello world"); + assert_eq!(msg.default_string_opt(), Optional::Set("hello world".into())); + assert!(msg.default_string_mut().is_set()); + assert_eq!(msg.default_string_mut().get(), "hello world"); + + msg.default_string_mut().or_default().set("accessors_test"); + assert_eq!(msg.default_string(), "accessors_test"); + assert_eq!(msg.default_string_opt(), Optional::Set("accessors_test".into())); + assert!(msg.default_string_mut().is_set()); + assert_eq!(msg.default_string_mut().get(), "accessors_test"); + assert_eq!(msg.default_string_mut().or_default().get(), "accessors_test"); + + msg.default_string_mut().clear(); + assert_eq!(msg.default_string(), "hello"); + assert_eq!(msg.default_string_opt(), Optional::Unset("hello".into())); + assert!(msg.default_string_mut().is_unset()); + + msg.default_string_mut().set(""); + assert_eq!(msg.default_string(), ""); + assert_eq!(msg.default_string_opt(), Optional::Set("".into())); + + msg.default_string_mut().clear(); + msg.default_string_mut().or_default(); + assert_eq!(msg.default_string(), "hello"); + assert_eq!(msg.default_string_opt(), Optional::Set("hello".into())); +} + #[test] fn test_singular_msg_field() { let msg = TestAllTypes::new(); diff --git a/src/google/protobuf/compiler/rust/BUILD.bazel b/src/google/protobuf/compiler/rust/BUILD.bazel index cda89d786a..f404beae97 100644 --- a/src/google/protobuf/compiler/rust/BUILD.bazel +++ b/src/google/protobuf/compiler/rust/BUILD.bazel @@ -51,9 +51,9 @@ cc_library( name = "accessors", srcs = [ "accessors/accessors.cc", - "accessors/singular_bytes.cc", "accessors/singular_message.cc", "accessors/singular_scalar.cc", + "accessors/singular_string.cc", "accessors/unsupported_field.cc", ], hdrs = [ diff --git a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h index 87e8ed886a..16808a16f9 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h +++ b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h @@ -93,9 +93,9 @@ class SingularScalar final : public AccessorGenerator { void InThunkCc(Context field) const override; }; -class SingularBytes final : public AccessorGenerator { +class SingularString final : public AccessorGenerator { public: - ~SingularBytes() override = default; + ~SingularString() override = default; void InMsgImpl(Context field) const override; void InExternC(Context field) const override; void InThunkCc(Context field) const override; diff --git a/src/google/protobuf/compiler/rust/accessors/accessors.cc b/src/google/protobuf/compiler/rust/accessors/accessors.cc index 1ca22f1dd6..15caa16518 100644 --- a/src/google/protobuf/compiler/rust/accessors/accessors.cc +++ b/src/google/protobuf/compiler/rust/accessors/accessors.cc @@ -72,7 +72,8 @@ std::unique_ptr AccessorGeneratorFor( case FieldDescriptor::TYPE_BOOL: return std::make_unique(); case FieldDescriptor::TYPE_BYTES: - return std::make_unique(); + case FieldDescriptor::TYPE_STRING: + return std::make_unique(); case FieldDescriptor::TYPE_MESSAGE: return std::make_unique(); diff --git a/src/google/protobuf/compiler/rust/accessors/singular_bytes.cc b/src/google/protobuf/compiler/rust/accessors/singular_string.cc similarity index 74% rename from src/google/protobuf/compiler/rust/accessors/singular_bytes.cc rename to src/google/protobuf/compiler/rust/accessors/singular_string.cc index 089ff63467..02e47f7f38 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_bytes.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_string.cc @@ -43,27 +43,42 @@ namespace protobuf { namespace compiler { namespace rust { -void SingularBytes::InMsgImpl(Context field) const { +void SingularString::InMsgImpl(Context field) const { std::string hazzer_thunk = Thunk(field, "has"); std::string getter_thunk = Thunk(field, "get"); std::string setter_thunk = Thunk(field, "set"); + std::string proxied_type = PrimitiveRsTypeName(field.desc()); + auto transform_view = [&] { + if (field.desc().type() == FieldDescriptor::TYPE_STRING) { + field.Emit(R"rs( + // SAFETY: The runtime doesn't require ProtoStr to be UTF-8. + unsafe { $pb$::ProtoStr::from_utf8_unchecked(view) } + )rs"); + } else { + field.Emit("view"); + } + }; field.Emit( { {"field", field.desc().name()}, {"hazzer_thunk", hazzer_thunk}, {"getter_thunk", getter_thunk}, {"setter_thunk", setter_thunk}, + {"proxied_type", proxied_type}, + {"transform_view", transform_view}, {"field_optional_getter", [&] { if (!field.desc().is_optional()) return; if (!field.desc().has_presence()) return; field.Emit({{"hazzer_thunk", hazzer_thunk}, - {"getter_thunk", getter_thunk}}, + {"getter_thunk", getter_thunk}, + {"transform_view", transform_view}}, R"rs( - pub fn $field$_opt(&self) -> $pb$::Optional<&[u8]> { + pub fn $field$_opt(&self) -> $pb$::Optional<&$proxied_type$> { unsafe { + let view = $getter_thunk$(self.inner.msg).as_ref(); $pb$::Optional::new( - $getter_thunk$(self.inner.msg).as_ref(), + $transform_view$ , $hazzer_thunk$(self.inner.msg) ) } @@ -76,15 +91,30 @@ void SingularBytes::InMsgImpl(Context field) const { field.Emit( { {"field", field.desc().name()}, + {"proxied_type", proxied_type}, {"default_val", absl::CHexEscape(field.desc().default_value_string())}, + {"view_type", proxied_type}, + {"transform_field_entry", + [&] { + if (field.desc().type() == + FieldDescriptor::TYPE_STRING) { + field.Emit(R"rs( + $pb$::ProtoStrMut::field_entry_from_bytes( + $pbi$::Private, out + ) + )rs"); + } else { + field.Emit("out"); + } + }}, {"hazzer_thunk", hazzer_thunk}, {"getter_thunk", getter_thunk}, {"setter_thunk", setter_thunk}, {"clearer_thunk", Thunk(field, "clear")}, }, R"rs( - pub fn $field$_mut(&mut self) -> $pb$::FieldEntry<'_, [u8]> { + pub fn $field$_mut(&mut self) -> $pb$::FieldEntry<'_, $proxied_type$> { static VTABLE: $pbi$::BytesOptionalMutVTable = unsafe { $pbi$::BytesOptionalMutVTable::new( $pbi$::Private, @@ -94,7 +124,7 @@ void SingularBytes::InMsgImpl(Context field) const { b"$default_val$", ) }; - unsafe { + let out = unsafe { let has = $hazzer_thunk$(self.inner.msg); $pbi$::new_vtable_field_entry( $pbi$::Private, @@ -103,15 +133,17 @@ void SingularBytes::InMsgImpl(Context field) const { &VTABLE, has, ) - } + }; + $transform_field_entry$ } )rs"); } else { field.Emit({{"field", field.desc().name()}, + {"proxied_type", proxied_type}, {"getter_thunk", getter_thunk}, {"setter_thunk", setter_thunk}}, R"rs( - pub fn $field$_mut(&mut self) -> $pb$::BytesMut<'_> { + pub fn $field$_mut(&mut self) -> $pb$::Mut<'_, $proxied_type$> { static VTABLE: $pbi$::BytesMutVTable = unsafe { $pbi$::BytesMutVTable::new( $pbi$::Private, @@ -120,7 +152,7 @@ void SingularBytes::InMsgImpl(Context field) const { ) }; unsafe { - $pb$::BytesMut::from_inner( + <$pb$::Mut<$proxied_type$>>::from_inner( $pbi$::Private, $pbi$::RawVTableMutator::new( $pbi$::Private, @@ -136,10 +168,9 @@ void SingularBytes::InMsgImpl(Context field) const { }}, }, R"rs( - pub fn r#$field$(&self) -> &[u8] { - unsafe { - $getter_thunk$(self.inner.msg).as_ref() - } + pub fn r#$field$(&self) -> &$proxied_type$ { + let view = unsafe { $getter_thunk$(self.inner.msg).as_ref() }; + $transform_view$ } $field_optional_getter$ @@ -147,7 +178,7 @@ void SingularBytes::InMsgImpl(Context field) const { )rs"); } -void SingularBytes::InExternC(Context field) const { +void SingularString::InExternC(Context field) const { field.Emit({{"hazzer_thunk", Thunk(field, "has")}, {"getter_thunk", Thunk(field, "get")}, {"setter_thunk", Thunk(field, "set")}, @@ -156,8 +187,8 @@ void SingularBytes::InExternC(Context field) const { [&] { if (field.desc().has_presence()) { field.Emit(R"rs( - fn $hazzer_thunk$(raw_msg: $pbi$::RawMessage) -> bool; - )rs"); + fn $hazzer_thunk$(raw_msg: $pbi$::RawMessage) -> bool; + )rs"); } }}}, R"rs( @@ -168,8 +199,8 @@ void SingularBytes::InExternC(Context field) const { )rs"); } -void SingularBytes::InThunkCc(Context field) const { - field.Emit({{"field", field.desc().name()}, +void SingularString::InThunkCc(Context field) const { + field.Emit({{"field", cpp::FieldName(&field.desc())}, {"QualifiedMsg", cpp::QualifiedClassName(field.desc().containing_type())}, {"hazzer_thunk", Thunk(field, "has")}, @@ -182,7 +213,9 @@ void SingularBytes::InThunkCc(Context field) const { field.Emit(R"cc( bool $hazzer_thunk$($QualifiedMsg$* msg) { return msg->has_$field$(); - })cc"); + } + void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } + )cc"); } }}}, R"cc( @@ -194,7 +227,6 @@ void SingularBytes::InThunkCc(Context field) const { void $setter_thunk$($QualifiedMsg$* msg, const char* ptr, ::std::size_t size) { msg->set_$field$(absl::string_view(ptr, size)); } - void $clearer_thunk$($QualifiedMsg$* msg) { msg->clear_$field$(); } )cc"); } diff --git a/src/google/protobuf/compiler/rust/naming.cc b/src/google/protobuf/compiler/rust/naming.cc index f686bcf877..e0821414ff 100644 --- a/src/google/protobuf/compiler/rust/naming.cc +++ b/src/google/protobuf/compiler/rust/naming.cc @@ -148,7 +148,9 @@ std::string PrimitiveRsTypeName(const FieldDescriptor& desc) { case FieldDescriptor::TYPE_DOUBLE: return "f64"; case FieldDescriptor::TYPE_BYTES: - return "&[u8]"; + return "[u8]"; + case FieldDescriptor::TYPE_STRING: + return "::__pb::ProtoStr"; default: break; }