diff --git a/src/map.rs b/src/map.rs index 5049aa2b5b..d5f34ba6c3 100644 --- a/src/map.rs +++ b/src/map.rs @@ -1730,9 +1730,15 @@ where where Q: Hash + Equivalent, { - let hashes = self.build_hashes_inner(ks); - self.table - .get_many_mut(hashes, |i, (k, _)| ks[i].equivalent(k)) + let hash_builder = &self.hash_builder; + + let mut iter = ks.into_iter().map(|key| { + ( + make_hash::(hash_builder, key), + equivalent_key::(key), + ) + }); + self.table.get_many_mut_from_iter(&mut iter) } unsafe fn get_many_unchecked_mut_inner( @@ -1742,20 +1748,15 @@ where where Q: Hash + Equivalent, { - let hashes = self.build_hashes_inner(ks); - self.table - .get_many_unchecked_mut(hashes, |i, (k, _)| ks[i].equivalent(k)) - } + let hash_builder = &self.hash_builder; - fn build_hashes_inner(&self, ks: [&Q; N]) -> [u64; N] - where - Q: Hash + Equivalent, - { - let mut hashes = [0_u64; N]; - for i in 0..N { - hashes[i] = make_hash::(&self.hash_builder, ks[i]); - } - hashes + let mut iter = ks.into_iter().map(|key| { + ( + make_hash::(hash_builder, key), + equivalent_key::(key), + ) + }); + self.table.get_many_unchecked_mut_from_iter(&mut iter) } /// Inserts a key-value pair into the map. diff --git a/src/raw/mod.rs b/src/raw/mod.rs index 24021ac4a3..118c0f6f33 100644 --- a/src/raw/mod.rs +++ b/src/raw/mod.rs @@ -865,6 +865,7 @@ impl RawTable { /// /// The `eq` argument should be a closure such that `eq(i, k)` returns true if `k` is equal to /// the `i`th key to be looked up. + #[cfg(feature = "raw")] pub fn get_many_mut( &mut self, hashes: [u64; N], @@ -886,6 +887,7 @@ impl RawTable { } } + #[cfg(feature = "raw")] pub unsafe fn get_many_unchecked_mut( &mut self, hashes: [u64; N], @@ -895,6 +897,7 @@ impl RawTable { Some(mem::transmute_copy(&ptrs)) } + #[cfg(feature = "raw")] unsafe fn get_many_mut_pointers( &mut self, hashes: [u64; N], @@ -913,6 +916,157 @@ impl RawTable { Some(outs.assume_init()) } + /// Attempts to get mutable references to `N element` in the table at once using + /// `hash` and equality function from iterator. + /// + /// The `iter` argument should be an iterator that return `hash` of the stored + /// `element` and closure for checking the equivalence of that `element`. + /// + /// This function return `None`: + /// + /// - if an `element` is not found for any item from the iterator; + /// - if any of the requested `elements` from table are duplicated; + /// - if the given `const N` is equal to zero (`0`). + /// - if the given iterator length is not equal to the specified `const N`; + #[allow(clippy::explicit_counter_loop)] + pub fn get_many_mut_from_iter<'a, 'b, I, F, const N: usize>( + &'a mut self, + iter: &'b mut I, + ) -> Option<[&'a mut T; N]> + where + I: Iterator, + F: FnMut(&T) -> bool, + { + let pointers: [*mut T; N] = self.get_many_mut_pointers_from_iter(iter)?; + + // Avoid using `Iterator::enumerate` because of double checking + let mut index = 0_usize; + for ¤t in pointers.iter() { + // SAFETY: we now exactly that the `index` less than `pointers` length + if unsafe { pointers.get_unchecked(..index) } + .iter() + .any(|&previous| previous == current) + { + return None; + } + index += 1; + } + + // SAFETY: All bucket are distinct from all previous buckets, `*mut T` and `&T` + // are guaranteed properly aligned and have the same layout, so we're clear to + // return the result of the lookup. + // Also no needance of using mem::forget(pointers) because it is just array of + // pointers. + Some(unsafe { (&pointers as *const _ as *const [&mut T; N]).read() }) + } + + /// Attempts to get mutable references to `N element` in the table at once using + /// `hash` and equality function from iterator, without checking the uniqueness + /// of the found elements. + /// + /// The `iter` argument should be an iterator that return `hash` of the stored + /// `element` and closure for checking the equivalence of that `element`. + /// + /// This function return `None`: + /// + /// - if an `element` is not found for any item from the iterator; + /// - if the given `const N` is equal to zero (`0`). + /// - if the given iterator length is not equal to the specified `const N`; + /// + /// # Safety + /// + /// Calling this method is *[undefined behavior]* if iterator contain overlapping + /// items that refer to the same `elements` in the table even if the resulting + /// references to `elements` in the table are not used. + /// + /// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html + pub unsafe fn get_many_unchecked_mut_from_iter<'a, 'b, I, F, const N: usize>( + &'a mut self, + iter: &'b mut I, + ) -> Option<[&'a mut T; N]> + where + I: Iterator, + F: FnMut(&T) -> bool, + { + let pointers: [*mut T; N] = self.get_many_mut_pointers_from_iter(iter)?; + + // SAFETY: the caller must uphold the safety contract for `get_many_unchecked_mut_from_iter`. + // We only know that `*mut T` and `&T` are guaranteed properly aligned and have the same layout. + // Also we know that there is no needance of using mem::forget(pointers) because it is just + // array of pointers. + Some((&pointers as *const _ as *const [&mut T; N]).read()) + } + + /// Attempts to get mutable pointers to `N element` in the table at once using + /// `hash` and equality function from iterator, without checking the uniqueness + /// of the found elements. + /// + /// The `iter` argument should be an iterator that return `hash` of the stored + /// `element` and closure for checking the equivalence of that `element`. + /// + /// This function return `None`: + /// + /// - if an `element` is not found for any item from the iterator; + /// - if the given `const N` is equal to zero (`0`). + /// - if the given iterator length is not equal to the specified `const N`; + /// + /// # Safety + /// + /// Calling this method is safe, but the returned array may contain overlapping + /// items pointing to the same `elements` in the table. + fn get_many_mut_pointers_from_iter( + &mut self, + iter: &mut I, + ) -> Option<[*mut T; N]> + where + I: Iterator, + F: FnMut(&T) -> bool, + { + // Check trivial cases + if N == 0 || N > self.len() { + return None; + } + + // If `iterator::size_hint` returns some upper bound, we check + // that it is equal to `const N`, else return from the function + if let (_, Some(upper_bound)) = iter.size_hint() { + if upper_bound != N { + return None; + } + } + + // SAFETY: An uninitialized `[MaybeUninit<_>; LEN]` is valid, + // because the type we are claiming to have initialized here is a + // bunch of `MaybeUninit`s, which do not require initialization. + // + // FIXME: Use `MaybeUninit::uninit_array` or `maybe_uninit_uninit_array_transpose` + // (https://github.com/rust-lang/rust/pull/102023) instead as soon as either becomes + // stable + let mut array = unsafe { MaybeUninit::<[MaybeUninit<*mut T>; N]>::uninit().assume_init() }; + + for element in &mut array { + match iter.next() { + Some((hash, eq)) => match self.find(hash, eq) { + Some(bucket) => { + element.write(bucket.as_ptr()); + } + None => return None, + }, + None => return None, + } + } + // SAFETY: All elements of the array were populated in the loop above, + // `MaybeUninit<*mut T>` and `*mut T` are guaranteed properly aligned and + // have the same layout. + // Also no needance of using mem::forget(array) because it is just array of + // pointers. + // + // FIXME: Use `MaybeUninit::array_assume_init` or `maybe_uninit_uninit_array_transpose` + // (https://github.com/rust-lang/rust/pull/102023) instead as soon as either becomes + // stable + Some(unsafe { (&array as *const _ as *const [*mut T; N]).read() }) + } + /// Returns the number of elements the map can hold without reallocating. /// /// This number is a lower bound; the table might be able to hold