diff --git a/rust/cpp.rs b/rust/cpp.rs index e79c9ba902..77e0454553 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -415,7 +415,7 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { fn [< __rust_proto_thunk__Map_ $key_t _ $t _free >](m: RawMap); fn [< __rust_proto_thunk__Map_ $key_t _ $t _clear >](m: RawMap); fn [< __rust_proto_thunk__Map_ $key_t _ $t _size >](m: RawMap) -> usize; - fn [< __rust_proto_thunk__Map_ $key_t _ $t _insert >](m: RawMap, key: $ffi_key_t, value: $ffi_t); + fn [< __rust_proto_thunk__Map_ $key_t _ $t _insert >](m: RawMap, key: $ffi_key_t, value: $ffi_t) -> bool; fn [< __rust_proto_thunk__Map_ $key_t _ $t _get >](m: RawMap, key: $ffi_key_t, value: *mut $ffi_t) -> bool; fn [< __rust_proto_thunk__Map_ $key_t _ $t _remove >](m: RawMap, key: $ffi_key_t, value: *mut $ffi_t) -> bool; } @@ -453,7 +453,6 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { let ffi_key = $to_ffi_key(key); let ffi_value = $to_ffi_value(value); unsafe { [< __rust_proto_thunk__Map_ $key_t _ $t _insert >](map.inner.raw, ffi_key, ffi_value) } - true } fn map_get<'a>(map: View<'a, Map<$key_t, Self>>, key: View<'_, $key_t>) -> Option> { diff --git a/rust/cpp_kernel/cpp_api.cc b/rust/cpp_kernel/cpp_api.cc index 404f6f60a6..0666308182 100644 --- a/rust/cpp_kernel/cpp_api.cc +++ b/rust/cpp_kernel/cpp_api.cc @@ -115,11 +115,10 @@ expose_repeated_ptr_field_methods(Bytes); const google::protobuf::Map* m) { \ return m->size(); \ } \ - void __rust_proto_thunk__Map_##rust_key_ty##_##rust_value_ty##_insert( \ + bool __rust_proto_thunk__Map_##rust_key_ty##_##rust_value_ty##_insert( \ google::protobuf::Map* m, ffi_key_ty key, ffi_value_ty value) { \ - auto cpp_key = to_cpp_key; \ - auto cpp_value = to_cpp_value; \ - (*m)[cpp_key] = cpp_value; \ + auto iter_and_inserted = m->try_emplace(to_cpp_key, to_cpp_value); \ + return iter_and_inserted.second; \ } \ bool __rust_proto_thunk__Map_##rust_key_ty##_##rust_value_ty##_get( \ const google::protobuf::Map* m, ffi_key_ty key, \ diff --git a/rust/map.rs b/rust/map.rs index 43f59c3c05..d7368f80ca 100644 --- a/rust/map.rs +++ b/rust/map.rs @@ -231,6 +231,9 @@ where self.len() == 0 } + /// Adds a key-value pair to the map. + /// + /// Returns `true` if the entry was newly inserted. pub fn insert<'a, 'b>( &mut self, key: impl Into>, diff --git a/rust/test/shared/accessors_map_test.rs b/rust/test/shared/accessors_map_test.rs index 19c020ad3a..85f9fb5c4c 100644 --- a/rust/test/shared/accessors_map_test.rs +++ b/rust/test/shared/accessors_map_test.rs @@ -20,6 +20,7 @@ macro_rules! generate_map_primitives_tests { let k: $k_type = Default::default(); let v: $v_type = Default::default(); assert_that!(msg.[< map_ $k_field _ $v_field _mut>]().insert(k, v), eq(true)); + assert_that!(msg.[< map_ $k_field _ $v_field _mut>]().insert(k, v), eq(false)); assert_that!(msg.[< map_ $k_field _ $v_field >]().len(), eq(1)); } )* } diff --git a/rust/upb.rs b/rust/upb.rs index 6a1c65e169..b28bc01a0e 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -832,7 +832,7 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { fn map_insert(map: Mut<'_, Map<$key_t, Self>>, key: View<'_, $key_t>, value: View<'_, Self>) -> bool { unsafe { - upb_Map_Set( + upb_Map_InsertAndReturnIfInserted( map.inner.raw, <$key_t as UpbTypeConversions>::to_message_value(key), <$t as UpbTypeConversions>::to_message_value_copy_if_required(map.inner.raw_arena, value), @@ -878,15 +878,49 @@ macro_rules! impl_ProxiedInMapValue_for_key_types { impl_ProxiedInMapValue_for_key_types!(i32, u32, i64, u64, bool, ProtoStr); +#[repr(C)] +#[allow(dead_code)] +enum upb_MapInsertStatus { + Inserted = 0, + Replaced = 1, + OutOfMemory = 2, +} + +/// `upb_Map_Insert`, but returns a `bool` for whether insert occurred. +/// +/// Returns `true` if the entry was newly inserted. +/// +/// # Panics +/// Panics if the arena is out of memory. +/// +/// # Safety +/// The same as `upb_Map_Insert`: +/// - `map` must be a valid map. +/// - The `arena` must be valid and outlive the map. +/// - The inserted value must outlive the map. +#[allow(non_snake_case)] +pub unsafe fn upb_Map_InsertAndReturnIfInserted( + map: RawMap, + key: upb_MessageValue, + value: upb_MessageValue, + arena: RawArena, +) -> bool { + match unsafe { upb_Map_Insert(map, key, value, arena) } { + upb_MapInsertStatus::Inserted => true, + upb_MapInsertStatus::Replaced => false, + upb_MapInsertStatus::OutOfMemory => panic!("map arena is out of memory"), + } +} + extern "C" { fn upb_Map_New(arena: RawArena, key_type: UpbCType, value_type: UpbCType) -> RawMap; fn upb_Map_Size(map: RawMap) -> usize; - fn upb_Map_Set( + fn upb_Map_Insert( map: RawMap, key: upb_MessageValue, value: upb_MessageValue, arena: RawArena, - ) -> bool; + ) -> upb_MapInsertStatus; fn upb_Map_Get(map: RawMap, key: upb_MessageValue, value: *mut upb_MessageValue) -> bool; fn upb_Map_Delete( map: RawMap,