diff --git a/rust/test/BUILD b/rust/test/BUILD index d0e86bb03f..3b05958c8a 100644 --- a/rust/test/BUILD +++ b/rust/test/BUILD @@ -342,23 +342,23 @@ rust_cc_proto_library( ) proto_library( - name = "reserved_proto", + name = "bad_names_proto", testonly = True, - srcs = ["reserved.proto"], + srcs = ["bad_names.proto"], ) rust_cc_proto_library( - name = "reserved_cc_rust_proto", + name = "bad_names_cc_rust_proto", testonly = True, visibility = ["//rust/test/shared:__subpackages__"], - deps = [":reserved_proto"], + deps = [":bad_names_proto"], ) rust_upb_proto_library( - name = "reserved_upb_rust_proto", + name = "bad_names_upb_rust_proto", testonly = True, visibility = ["//rust/test/shared:__subpackages__"], - deps = [":reserved_proto"], + deps = [":bad_names_proto"], ) proto_library( diff --git a/rust/test/reserved.proto b/rust/test/bad_names.proto similarity index 53% rename from rust/test/reserved.proto rename to rust/test/bad_names.proto index 1f8c2a937d..21fbf3b7cb 100644 --- a/rust/test/reserved.proto +++ b/rust/test/bad_names.proto @@ -6,8 +6,10 @@ // https://developers.google.com/open-source/licenses/bsd // LINT: LEGACY_NAMES -// The purpose of this file is to be as hostile as possible to reserved words -// to the Rust language and ensure it still works. + +// The purpose of this file is to be hostile on field/message/enum naming and +// ensure that it works (e.g. collisions between names and language keywords, +// collisions between two different field's accessor's names). syntax = "proto2"; @@ -36,3 +38,17 @@ message Ref { .type.type.Pub.Self const = 3; } } + +// A message where the accessors would collide that should still work. Note that +// not all collisions problems are avoided, not least because C++ Proto does not +// avoid all possible collisions (eg a field `x` and `clear_x` will often not +// compile on C++). +message AccessorsCollide { + message X {} + message SetX {} + optional SetX set_x = 2; + optional X x = 3; + oneof o { + bool x_mut = 5; + } +} diff --git a/rust/test/shared/BUILD b/rust/test/shared/BUILD index 7210cc2374..460542bf46 100644 --- a/rust/test/shared/BUILD +++ b/rust/test/shared/BUILD @@ -159,24 +159,26 @@ rust_test( ) rust_test( - name = "reserved_cpp_test", - srcs = ["reserved_test.rs"], + name = "bad_names_cpp_test", + srcs = ["bad_names_test.rs"], deps = [ - "//rust/test:reserved_cc_rust_proto", + "//rust/test:bad_names_cc_rust_proto", "//rust/test:unittest_cc_rust_proto", "@crate_index//:googletest", ], ) -rust_test( - name = "reserved_upb_test", - srcs = ["reserved_test.rs"], - deps = [ - "//rust/test:reserved_upb_rust_proto", - "//rust/test:unittest_upb_rust_proto", - "@crate_index//:googletest", - ], -) +# TODO: This test currently fails on upb due to the thunk names not correctly matching +# the upb C codegen collision avoidance. +# rust_test( +# name = "bad_names_upb_test", +# srcs = ["bad_names_test.rs"], +# deps = [ +# "@crate_index//:googletest", +# "//rust/test:bad_names_upb_rust_proto", +# "//rust/test:unittest_upb_rust_proto", +# ], +# ) rust_test( name = "nested_types_cpp_test", diff --git a/rust/test/shared/reserved_test.rs b/rust/test/shared/bad_names_test.rs similarity index 68% rename from rust/test/shared/reserved_test.rs rename to rust/test/shared/bad_names_test.rs index 0d2cb52bb6..a4b133b92f 100644 --- a/rust/test/shared/reserved_test.rs +++ b/rust/test/shared/bad_names_test.rs @@ -5,10 +5,8 @@ // license that can be found in the LICENSE file or at // https://developers.google.com/open-source/licenses/bsd -/// Test covering proto compilation with reserved words. +use bad_names_proto::*; use googletest::prelude::*; -use reserved_proto::Self__mangled_because_ident_isnt_a_legal_raw_identifier; -use reserved_proto::{r#enum, Ref}; #[test] fn test_reserved_keyword_in_accessors() { @@ -22,3 +20,13 @@ fn test_reserved_keyword_in_messages() { let _ = r#enum::new(); let _ = Ref::new().r#const(); } + +#[test] +fn test_collision_in_accessors() { + let mut m = AccessorsCollide::new(); + m.set_x_mut_5(false); + assert_that!(m.x_mut_5(), eq(false)); + assert_that!(m.has_x_mut_5(), eq(true)); + assert_that!(m.has_x(), eq(false)); + assert_that!(m.has_set_x_2(), eq(false)); +} diff --git a/src/google/protobuf/compiler/rust/accessors/map.cc b/src/google/protobuf/compiler/rust/accessors/map.cc index ee0d089784..75aa5007ee 100644 --- a/src/google/protobuf/compiler/rust/accessors/map.cc +++ b/src/google/protobuf/compiler/rust/accessors/map.cc @@ -39,8 +39,9 @@ void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field, AccessorCase accessor_case) const { auto& key_type = *field.message_type()->map_key(); auto& value_type = *field.message_type()->map_value(); + std::string field_name = FieldNameWithCollisionAvoidance(field); - ctx.Emit({{"field", RsSafeName(field.name())}, + ctx.Emit({{"field", RsSafeName(field_name)}, {"Key", RsTypePath(ctx, key_type)}, {"Value", RsTypePath(ctx, value_type)}, {"view_lifetime", ViewLifetime(accessor_case)}, diff --git a/src/google/protobuf/compiler/rust/accessors/repeated_field.cc b/src/google/protobuf/compiler/rust/accessors/repeated_field.cc index c48b437847..e1e14e646f 100644 --- a/src/google/protobuf/compiler/rust/accessors/repeated_field.cc +++ b/src/google/protobuf/compiler/rust/accessors/repeated_field.cc @@ -22,7 +22,8 @@ namespace rust { void RepeatedField::InMsgImpl(Context& ctx, const FieldDescriptor& field, AccessorCase accessor_case) const { - ctx.Emit({{"field", RsSafeName(field.name())}, + std::string field_name = FieldNameWithCollisionAvoidance(field); + ctx.Emit({{"field", RsSafeName(field_name)}, {"RsType", RsTypePath(ctx, field)}, {"view_lifetime", ViewLifetime(accessor_case)}, {"view_self", ViewReceiver(accessor_case)}, diff --git a/src/google/protobuf/compiler/rust/accessors/singular_message.cc b/src/google/protobuf/compiler/rust/accessors/singular_message.cc index 8f3ea9bfe0..73acd60164 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_message.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_message.cc @@ -24,9 +24,10 @@ void SingularMessage::InMsgImpl(Context& ctx, const FieldDescriptor& field, AccessorCase accessor_case) const { // fully qualified message name with modules prefixed std::string msg_type = RsTypePath(ctx, field); + std::string field_name = FieldNameWithCollisionAvoidance(field); ctx.Emit({{"msg_type", msg_type}, - {"field", RsSafeName(field.name())}, - {"raw_field_name", field.name()}, + {"field", RsSafeName(field_name)}, + {"raw_field_name", field_name}, {"view_lifetime", ViewLifetime(accessor_case)}, {"view_self", ViewReceiver(accessor_case)}, {"getter_thunk", ThunkName(ctx, field, "get")}, diff --git a/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc b/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc index bf8bf622d3..2c3546cbd9 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_scalar.cc @@ -23,10 +23,11 @@ namespace rust { void SingularScalar::InMsgImpl(Context& ctx, const FieldDescriptor& field, AccessorCase accessor_case) const { + std::string field_name = FieldNameWithCollisionAvoidance(field); ctx.Emit( { - {"field", RsSafeName(field.name())}, - {"raw_field_name", field.name()}, // Never r# prefixed + {"field", RsSafeName(field_name)}, + {"raw_field_name", field_name}, // Never r# prefixed {"view_self", ViewReceiver(accessor_case)}, {"Scalar", RsTypePath(ctx, field)}, {"hazzer_thunk", ThunkName(ctx, field, "has")}, diff --git a/src/google/protobuf/compiler/rust/accessors/singular_string.cc b/src/google/protobuf/compiler/rust/accessors/singular_string.cc index a9a8b27940..1de8f5b7ae 100644 --- a/src/google/protobuf/compiler/rust/accessors/singular_string.cc +++ b/src/google/protobuf/compiler/rust/accessors/singular_string.cc @@ -5,6 +5,8 @@ // license that can be found in the LICENSE file or at // https://developers.google.com/open-source/licenses/bsd +#include + #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" #include "google/protobuf/compiler/rust/accessors/accessor_case.h" @@ -21,10 +23,11 @@ namespace rust { void SingularString::InMsgImpl(Context& ctx, const FieldDescriptor& field, AccessorCase accessor_case) const { + std::string field_name = FieldNameWithCollisionAvoidance(field); ctx.Emit( { - {"field", RsSafeName(field.name())}, - {"raw_field_name", field.name()}, + {"field", RsSafeName(field_name)}, + {"raw_field_name", field_name}, {"hazzer_thunk", ThunkName(ctx, field, "has")}, {"getter_thunk", ThunkName(ctx, field, "get")}, {"setter_thunk", ThunkName(ctx, field, "set")}, diff --git a/src/google/protobuf/compiler/rust/naming.cc b/src/google/protobuf/compiler/rust/naming.cc index 6879e4ae45..78d96b8393 100644 --- a/src/google/protobuf/compiler/rust/naming.cc +++ b/src/google/protobuf/compiler/rust/naming.cc @@ -270,6 +270,40 @@ std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field) { return comment; } +static constexpr absl::string_view kAccessorPrefixes[] = {"clear_", "has_", + "set_"}; + +static constexpr absl::string_view kAccessorSuffixes[] = {"_mut", "_opt"}; + +std::string FieldNameWithCollisionAvoidance(const FieldDescriptor& field) { + absl::string_view name = field.name(); + const Descriptor& msg = *field.containing_type(); + + for (absl::string_view prefix : kAccessorPrefixes) { + if (absl::StartsWith(name, prefix)) { + absl::string_view without_prefix = name; + without_prefix.remove_prefix(prefix.size()); + + if (msg.FindFieldByName(without_prefix) != nullptr) { + return absl::StrCat(name, "_", field.number()); + } + } + } + + for (absl::string_view suffix : kAccessorSuffixes) { + if (absl::EndsWith(name, suffix)) { + absl::string_view without_suffix = name; + without_suffix.remove_suffix(suffix.size()); + + if (msg.FindFieldByName(without_suffix) != nullptr) { + return absl::StrCat(name, "_", field.number()); + } + } + } + + return std::string(name); +} + std::string RsSafeName(absl::string_view name) { if (!IsLegalRawIdentifierName(name)) { return absl::StrCat(name, diff --git a/src/google/protobuf/compiler/rust/naming.h b/src/google/protobuf/compiler/rust/naming.h index f3d7af797e..70eebe1b64 100644 --- a/src/google/protobuf/compiler/rust/naming.h +++ b/src/google/protobuf/compiler/rust/naming.h @@ -55,6 +55,25 @@ std::string OneofCaseRsName(const FieldDescriptor& oneof_field); std::string FieldInfoComment(Context& ctx, const FieldDescriptor& field); +// Return how to name a field with 'collision avoidance'. This adds a suffix +// of the field number to the field name if it appears that it will collide with +// another field's non-getter accessor. +// +// For example, for the message: +// message M { bool set_x = 1; int32 x = 2; string x_mut = 8; } +// All accessors for the field `set_x` will be constructed as though the field +// was instead named `set_x_1`, and all accessors for `x_mut` will be as though +// the field was instead named `x_mut_8`. +// +// This is a best-effort heuristic to avoid realistic accidental +// collisions. It is still possible to create a message definition that will +// have a collision, and it may rename a field even if there's no collision (as +// in the case of x_mut in the example). +// +// Note the returned name may still be a rust keyword: RsSafeName() should +// additionally be used if there is no prefix/suffix being appended to the name. +std::string FieldNameWithCollisionAvoidance(const FieldDescriptor& field); + // Returns how to 'spell' the provided name in Rust, which is the provided name // verbatim unless it is a Rust keyword that isn't a legal symbol name. std::string RsSafeName(absl::string_view name); diff --git a/src/google/protobuf/compiler/rust/oneof.cc b/src/google/protobuf/compiler/rust/oneof.cc index 8453af898b..fc8cef2f18 100644 --- a/src/google/protobuf/compiler/rust/oneof.cc +++ b/src/google/protobuf/compiler/rust/oneof.cc @@ -193,10 +193,11 @@ void GenerateOneofAccessors(Context& ctx, const OneofDescriptor& oneof, if (rs_type.empty()) { continue; } + std::string field_name = FieldNameWithCollisionAvoidance(field); ctx.Emit( { {"case", OneofCaseRsName(field)}, - {"rs_getter", RsSafeName(field.name())}, + {"rs_getter", RsSafeName(field_name)}, {"type", rs_type}, }, R"rs(