From ba79df22831bfd711b60cbfaae206d8332119628 Mon Sep 17 00:00:00 2001 From: Rouven Spreckels Date: Thu, 13 Jun 2024 07:47:37 +0200 Subject: [PATCH] Spill recursion stack over to heap if necessary. --- .github/workflows/build.yml | 2 +- Cargo.toml | 7 ++++--- README.md | 1 + RELEASES.md | 4 ++++ src/lib.rs | 13 +++++++++++++ src/par/merge_sort.rs | 2 +- src/par/partition.rs | 39 ++++++++++++++++++++++++++----------- src/par/quick_sort.rs | 9 +++++---- src/partition.rs | 4 ++-- src/quick_sort.rs | 5 +++-- 10 files changed, 62 insertions(+), 24 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 156220b..50f3a7b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,4 +84,4 @@ jobs: - name: fmt run: cargo fmt --check - name: miri - run: cargo miri test -- Slice1Ext + run: cargo miri test --no-default-features --features std -- Slice1Ext diff --git a/Cargo.toml b/Cargo.toml index d7ac748..233dda6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ndarray-slice" -version = "0.3.0" +version = "0.3.1" rust-version = "1.65.0" edition = "2021" authors = ["Rouven Spreckels "] @@ -37,7 +37,8 @@ rustdoc-args = ["--cfg", "docsrs"] [dependencies] ndarray = { version = "0.15.6", default-features = false } -rayon = { version = "1.9.0", optional = true } +stacker = { version = "0.1.15", optional = true } +rayon = { version = "1.10.0", optional = true } [dev-dependencies] quickcheck = "1.0.3" @@ -45,7 +46,7 @@ quickcheck_macros = "1.0.0" rand = "0.8.5" [features] -default = ["std"] +default = ["std", "stacker"] alloc = [] std = ["alloc", "ndarray/std"] rayon = ["dep:rayon", "ndarray/rayon", "std"] diff --git a/README.md b/README.md index 21f0cba..02dcb87 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,7 @@ See the [release history](RELEASES.md) to keep track of the development. * `alloc` for stable `sort`/`sort_by`/`sort_by_key`. Enabled by `std`. * `std` for stable `sort_by_cached_key`. Enabled by `default` or `rayon`. + * `stacker` for spilling recursion stack over to heap if necessary. Enabled by `default`. * `rayon` for parallel `par_sort*`/`par_select_many_nth_unstable*`. # License diff --git a/RELEASES.md b/RELEASES.md index c3c4326..3fb8ca3 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,3 +1,7 @@ +# Version 0.3.1 (2024-06-13) + + * Spill recursion stack over to heap if necessary. + # Version 0.3.0 (2024-03-19) * Synchronize with Rust standard library. diff --git a/src/lib.rs b/src/lib.rs index 670590f..c25ea68 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,6 +70,7 @@ //! //! * `alloc` for stable `sort`/`sort_by`/`sort_by_key`. Enabled by `std`. //! * `std` for stable `sort_by_cached_key`. Enabled by `default` or `rayon`. +//! * `stacker` for spilling recursion stack over to heap if necessary. Enabled by `default`. //! * `rayon` for parallel `par_sort*`/`par_select_many_nth_unstable*`. #![deny( @@ -81,6 +82,18 @@ #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![cfg_attr(miri, feature(strict_provenance), feature(maybe_uninit_slice))] +#[inline(always)] +fn maybe_grow R>(callback: F) -> R { + #[cfg(feature = "stacker")] + { + stacker::maybe_grow(32 * 1_024, 1_024 * 1_024, callback) + } + #[cfg(not(feature = "stacker"))] + { + callback() + } +} + mod heap_sort; mod insertion_sort; mod merge_sort; diff --git a/src/par/merge_sort.rs b/src/par/merge_sort.rs index 3175eef..faedb14 100644 --- a/src/par/merge_sort.rs +++ b/src/par/merge_sort.rs @@ -36,7 +36,7 @@ impl SendPtr { // Implement Clone without the T: Clone bound from the derive impl Clone for SendPtr { fn clone(&self) -> Self { - Self(self.0) + *self } } diff --git a/src/par/partition.rs b/src/par/partition.rs index d9a64fe..91e9420 100644 --- a/src/par/partition.rs +++ b/src/par/partition.rs @@ -4,6 +4,7 @@ use crate::{ insertion_sort::InsertionHole, + maybe_grow, par::insertion_sort::insertion_sort_shift_left, partition::{break_patterns, reverse}, }; @@ -48,15 +49,31 @@ pub fn par_partition_at_indices<'a, T, F>( let (left_values, right_values) = values.split_at_mut(at); let right_values = &mut right_values[1..]; if at == 0 || pivot - offset <= MAX_SEQUENTIAL { - par_partition_at_indices(left, offset, left_indices, left_values, is_less); + maybe_grow(|| { + par_partition_at_indices(left, offset, left_indices, left_values, is_less) + }); v = right; offset = pivot + 1; indices = right_indices; values = right_values; } else { rayon::join( - || par_partition_at_indices(left, offset, left_indices, left_values, is_less), - || par_partition_at_indices(right, pivot + 1, right_indices, right_values, is_less), + || { + maybe_grow(|| { + par_partition_at_indices(left, offset, left_indices, left_values, is_less) + }) + }, + || { + maybe_grow(|| { + par_partition_at_indices( + right, + pivot + 1, + right_indices, + right_values, + is_less, + ) + }) + }, ); break; } @@ -641,15 +658,15 @@ where if count > 0 { macro_rules! left { - () => { - v.view_mut().index(l + usize::from(*start_l)) as *mut T //l.add(usize::from(*start_l)) - }; - } + () => { + v.view_mut().index(l + usize::from(*start_l)) as *mut T //l.add(usize::from(*start_l)) + }; + } macro_rules! right { - () => { - v.view_mut().index(r - (usize::from(*start_r) + 1)) as *mut T //r.sub(usize::from(*start_r) + 1) - }; - } + () => { + v.view_mut().index(r - (usize::from(*start_r) + 1)) as *mut T //r.sub(usize::from(*start_r) + 1) + }; + } // Instead of swapping one pair at the time, it is more efficient to perform a cyclic // permutation. This is not strictly equivalent to swapping, but produces a similar diff --git a/src/par/quick_sort.rs b/src/par/quick_sort.rs index 2688bb2..7f897a5 100644 --- a/src/par/quick_sort.rs +++ b/src/par/quick_sort.rs @@ -3,6 +3,7 @@ //! [`rayon::slice::quicksort`]: https://docs.rs/rayon/latest/src/rayon/slice/quicksort.rs.html use crate::{ + maybe_grow, par::{ heap_sort::heap_sort, insertion_sort::{insertion_sort_shift_left, partial_insertion_sort}, @@ -123,18 +124,18 @@ fn recurse<'a, T, F>( // calls and consume less stack space. Then just continue with the longer side (this is // akin to tail recursion). if left.len() < right.len() { - recurse(left, is_less, pred, limit); + maybe_grow(|| recurse(left, is_less, pred, limit)); v = right; pred = Some(pivot); } else { - recurse(right, is_less, Some(pivot), limit); + maybe_grow(|| recurse(right, is_less, Some(pivot), limit)); v = left; } } else { // Sort the left and right half in parallel. rayon::join( - || recurse(left, is_less, pred, limit), - || recurse(right, is_less, Some(pivot), limit), + || maybe_grow(|| recurse(left, is_less, pred, limit)), + || maybe_grow(|| recurse(right, is_less, Some(pivot), limit)), ); break; } diff --git a/src/partition.rs b/src/partition.rs index e503b06..65a6f23 100644 --- a/src/partition.rs +++ b/src/partition.rs @@ -2,7 +2,7 @@ //! //! [`core::slice::sort`]: https://doc.rust-lang.org/src/core/slice/sort.rs.html -use crate::insertion_sort::insertion_sort_shift_left; +use crate::{insertion_sort::insertion_sort_shift_left, maybe_grow}; use core::{ cmp::{ self, @@ -33,7 +33,7 @@ pub fn partition_at_indices<'a, T, E, F>( let (index, right_indices) = right_indices.split_at(Axis(0), 1); let pivot = *index.index(0); let (left, value, right) = partition_at_index(v, pivot - offset, is_less); - partition_at_indices(left, offset, left_indices, collection, is_less); + maybe_grow(|| partition_at_indices(left, offset, left_indices, collection, is_less)); collection.extend([(pivot, value)]); v = right; offset = pivot + 1; diff --git a/src/quick_sort.rs b/src/quick_sort.rs index 1c811a6..a5c3971 100644 --- a/src/quick_sort.rs +++ b/src/quick_sort.rs @@ -6,6 +6,7 @@ use crate::{ heap_sort::heap_sort, insertion_sort::insertion_sort_shift_left, insertion_sort::partial_insertion_sort, + maybe_grow, partition::{break_patterns, choose_pivot, partition, partition_equal}, }; use core::{cmp, mem}; @@ -115,11 +116,11 @@ fn recurse<'a, T, F>( // calls and consume less stack space. Then just continue with the longer side (this is // akin to tail recursion). if left.len() < right.len() { - recurse(left, is_less, pred, limit); + maybe_grow(|| recurse(left, is_less, pred, limit)); v = right; pred = Some(pivot); } else { - recurse(right, is_less, Some(pivot), limit); + maybe_grow(|| recurse(right, is_less, Some(pivot), limit)); v = left; } }