diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 6f5454d1a5..02f85d0ab6 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -9,6 +9,7 @@ #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" #include "google/protobuf/compiler/cpp/names.h" @@ -163,6 +164,10 @@ void MessageDrop(Context& ctx, const Descriptor& msg) { )rs"); } +bool IsStringOrBytes(FieldDescriptor::Type t) { + return t == FieldDescriptor::TYPE_STRING || t == FieldDescriptor::TYPE_BYTES; +} + void GetterForViewOrMut(Context& ctx, const FieldDescriptor& field, bool is_mut) { auto fieldName = field.name(); @@ -218,89 +223,62 @@ void GetterForViewOrMut(Context& ctx, const FieldDescriptor& field, } auto rsType = PrimitiveRsTypeName(field); - if (fieldType == FieldDescriptor::TYPE_STRING || - fieldType == FieldDescriptor::TYPE_BYTES) { - ctx.Emit({{"field", fieldName}, - {"self", self}, - {"getter_thunk", getter_thunk}, - {"setter_thunk", setter_thunk}, - {"RsType", rsType}, - {"maybe_mutator", - [&] { - if (is_mut) { - ctx.Emit({}, R"rs( - pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> { - static VTABLE: $pbi$::BytesMutVTable = - $pbi$::BytesMutVTable::new( - $pbi$::Private, - $getter_thunk$, - $setter_thunk$, - ); - - unsafe { - <$pb$::Mut<$RsType$>>::from_inner( - $pbi$::Private, - $pbi$::RawVTableMutator::new( - $pbi$::Private, - self.inner, - &VTABLE, - ) - ) - } - } - )rs"); - } - }}}, - R"rs( - pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> { - let s = unsafe { $getter_thunk$($self$).as_ref() }; - unsafe { __pb::ProtoStr::from_utf8_unchecked(s).into() } - } - - $maybe_mutator$ - )rs"); - } else { - ctx.Emit({{"field", fieldName}, - {"getter_thunk", getter_thunk}, - {"setter_thunk", setter_thunk}, - {"clearer_thunk", clearer_thunk}, - {"self", self}, - {"RsType", rsType}, - {"maybe_mutator", - [&] { - // TODO: once the rust public api is accessible, - // by tooling, ensure that this only appears for the - // mutational pathway - if (is_mut && fieldType) { - ctx.Emit({}, R"rs( - pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> { - static VTABLE: $pbi$::PrimitiveVTable<$RsType$> = - $pbi$::PrimitiveVTable::new( - $pbi$::Private, - $getter_thunk$, - $setter_thunk$); - unsafe { - $pb$::PrimitiveMut::from_inner( + auto asRef = IsStringOrBytes(fieldType) ? ".as_ref()" : ""; + auto vtable = + IsStringOrBytes(fieldType) ? "BytesMutVTable" : "PrimitiveVTable"; + // PrimitiveVtable is parameterized based on the underlying primitive, like + // u32 so we need to provide this additional type arg + auto optionalTypeArgs = + IsStringOrBytes(fieldType) ? "" : absl::StrFormat("<%s>", rsType); + // need to stuff ProtoStr and [u8] behind a reference since they are DSTs + auto stringTransform = + IsStringOrBytes(fieldType) + ? "unsafe { __pb::ProtoStr::from_utf8_unchecked(res).into() }" + : "res"; + + ctx.Emit({{"field", fieldName}, + {"getter_thunk", getter_thunk}, + {"setter_thunk", setter_thunk}, + {"clearer_thunk", clearer_thunk}, + {"self", self}, + {"RsType", rsType}, + {"as_ref", asRef}, + {"vtable", vtable}, + {"optional_type_args", optionalTypeArgs}, + {"string_transform", stringTransform}, + {"maybe_mutator", + [&] { + // TODO: check mutational pathway genn'd correctly + if (is_mut) { + ctx.Emit({}, R"rs( + pub fn r#$field$_mut(&self) -> $pb$::Mut<'_, $RsType$> { + static VTABLE: $pbi$::$vtable$$optional_type_args$ = + $pbi$::$vtable$::new( + $pbi$::Private, + $getter_thunk$, + $setter_thunk$); + unsafe { + <$pb$::Mut<$RsType$>>::from_inner( + $pbi$::Private, + $pbi$::RawVTableMutator::new( $pbi$::Private, - $pbi$::RawVTableMutator::new( - $pbi$::Private, - self.inner, - &VTABLE - ), - ) - } + self.inner, + &VTABLE + ), + ) } - )rs"); - } - }}}, - R"rs( - pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> { - unsafe { $getter_thunk$($self$) } - } + } + )rs"); + } + }}}, + R"rs( + pub fn r#$field$(&self) -> $pb$::View<'_, $RsType$> { + let res = unsafe { $getter_thunk$($self$)$as_ref$ }; + $string_transform$ + } - $maybe_mutator$ - )rs"); - } + $maybe_mutator$ + )rs"); } void AccessorsForViewOrMut(Context& ctx, const Descriptor& msg, bool is_mut) {