diff --git a/src/array_serde.rs b/src/array_serde.rs index f1c9e1219..056fb3024 100644 --- a/src/array_serde.rs +++ b/src/array_serde.rs @@ -17,7 +17,7 @@ use alloc::vec::Vec; use crate::imp_prelude::*; use super::arraytraits::ARRAY_FORMAT_VERSION; -use super::Iter; +use super::iter::Iter; use crate::IntoDimension; /// Verifies that the version of the deserialized array matches the current diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index f66b83195..c425023f8 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -784,36 +784,6 @@ where } } -/// Move the axis which has the smallest absolute stride and a length -/// greater than one to be the last axis. -pub fn move_min_stride_axis_to_last(dim: &mut D, strides: &mut D) -where - D: Dimension, -{ - debug_assert_eq!(dim.ndim(), strides.ndim()); - match dim.ndim() { - 0 | 1 => {} - 2 => { - if dim[1] <= 1 - || dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs() - { - dim.slice_mut().swap(0, 1); - strides.slice_mut().swap(0, 1); - } - } - n => { - if let Some(min_stride_axis) = (0..n) - .filter(|&ax| dim[ax] > 1) - .min_by_key(|&ax| (strides[ax] as isize).abs()) - { - let last = n - 1; - dim.slice_mut().swap(last, min_stride_axis); - strides.slice_mut().swap(last, min_stride_axis); - } - } - } -} - /// Remove axes with length one, except never removing the last axis. pub(crate) fn squeeze(dim: &mut D, strides: &mut D) where @@ -857,7 +827,9 @@ pub(crate) fn sort_axes_to_standard(dim: &mut D, strides: &mut D) where D: Dimension, { - debug_assert!(dim.ndim() > 1); + if dim.ndim() <= 1 { + return; + } debug_assert_eq!(dim.ndim(), strides.ndim()); // bubble sort axes let mut changed = true; @@ -865,6 +837,7 @@ where changed = false; for i in 0..dim.ndim() - 1 { // make sure higher stride axes sort before. + debug_assert!(strides.get_stride(Axis(i)) >= 0); if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() { changed = true; dim.slice_mut().swap(i, i + 1); diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 745a8e60b..12af52dfb 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -19,7 +19,7 @@ use crate::argument_traits::AssignElem; use crate::dimension; use crate::dimension::IntoDimension; use crate::dimension::{ - abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last, + abs_index, axes_of, do_slice, merge_axes, offset_from_low_addr_ptr_to_logical_ptr, size_of_shape_checked, stride_offset, Axes, }; use crate::dimension::broadcast::co_broadcast; @@ -433,7 +433,7 @@ where where S: Data, { - IndexedIter::new(self.view().into_elements_base()) + IndexedIter::new(self.view().into_elements_base_keep_dims()) } /// Return an iterator of indexes and mutable references to the elements of the array. @@ -446,7 +446,7 @@ where where S: DataMut, { - IndexedIterMut::new(self.view_mut().into_elements_base()) + IndexedIterMut::new(self.view_mut().into_elements_base_keep_dims()) } /// Return a sliced view of the array. @@ -2441,9 +2441,7 @@ where if let Some(slc) = self.as_slice_memory_order() { slc.iter().fold(init, f) } else { - let mut v = self.view(); - move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); - v.into_elements_base().fold(init, f) + self.view().into_elements_base_any_order().fold(init, f) } } @@ -2599,9 +2597,7 @@ where match self.try_as_slice_memory_order_mut() { Ok(slc) => slc.iter_mut().for_each(f), Err(arr) => { - let mut v = arr.view_mut(); - move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); - v.into_elements_base().for_each(f); + arr.view_mut().into_elements_base_any_order().for_each(f); } } } diff --git a/src/impl_views/conversions.rs b/src/impl_views/conversions.rs index 1e9a1499f..d29f1a794 100644 --- a/src/impl_views/conversions.rs +++ b/src/impl_views/conversions.rs @@ -12,10 +12,10 @@ use std::mem::MaybeUninit; use crate::imp_prelude::*; -use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut}; - use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; -use crate::iter::{self, AxisIter, AxisIterMut}; +use crate::iter::{self, Iter, IterMut, AxisIter, AxisIterMut}; +use crate::iterators::base::{Baseiter, ElementsBase, ElementsBaseMut, OrderOption, PreserveOrder, + ArbitraryOrder, NoOptimization}; use crate::math_cell::MathCell; use crate::IndexLonger; @@ -188,14 +188,25 @@ impl<'a, A, D> ArrayView<'a, A, D> where D: Dimension, { + /// Create a base iter fromt the view with the given order option + #[inline] + pub(crate) fn into_base_iter(self) -> Baseiter { + unsafe { Baseiter::new_with_order::(self.ptr.as_ptr(), self.dim, self.strides) } + } + + #[inline] + pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBase<'a, A, D> { + ElementsBase::new::(self) + } + #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) } + pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBase<'a, A, D> { + ElementsBase::new::(self) } #[inline] - pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> { - ElementsBase::new(self) + pub(crate) fn into_elements_base_any_order(self) -> ElementsBase<'a, A, D> { + ElementsBase::new::(self) } pub(crate) fn into_iter_(self) -> Iter<'a, A, D> { @@ -227,16 +238,28 @@ where unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) } } + /// Create a base iter fromt the view with the given order option #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) } + pub(crate) fn into_base_iter(self) -> Baseiter { + unsafe { Baseiter::new_with_order::(self.ptr.as_ptr(), self.dim, self.strides) } } #[inline] - pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> { - ElementsBaseMut::new(self) + pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBaseMut<'a, A, D> { + ElementsBaseMut::new::(self) } + #[inline] + pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBaseMut<'a, A, D> { + ElementsBaseMut::new::(self) + } + + #[inline] + pub(crate) fn into_elements_base_any_order(self) -> ElementsBaseMut<'a, A, D> { + ElementsBaseMut::new::(self) + } + + /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Otherwise return self in the Err branch of the result. pub(crate) fn try_into_slice(self) -> Result<&'a mut [A], Self> { diff --git a/src/iterators/base.rs b/src/iterators/base.rs index a884b18b6..3d1a50dda 100644 --- a/src/iterators/base.rs +++ b/src/iterators/base.rs @@ -71,9 +71,7 @@ impl Baseiter { /// Return the iter strides pub(crate) fn raw_strides(&self) -> D { self.strides.clone() } -} -impl Baseiter { /// Creating a Baseiter is unsafe because shape and stride parameters need /// to be correct to avoid performing an unsafe pointer offset while /// iterating. @@ -252,9 +250,9 @@ clone_bounds!( ); impl<'a, A, D: Dimension> ElementsBase<'a, A, D> { - pub fn new(v: ArrayView<'a, A, D>) -> Self { + pub fn new(v: ArrayView<'a, A, D>) -> Self { ElementsBase { - inner: v.into_base_iter(), + inner: v.into_base_iter::(), life: PhantomData, } } @@ -338,7 +336,7 @@ where inner: if let Some(slc) = self_.to_slice() { ElementsRepr::Slice(slc.iter()) } else { - ElementsRepr::Counted(self_.into_elements_base()) + ElementsRepr::Counted(self_.into_elements_base_preserve_order()) }, } } @@ -352,7 +350,7 @@ where IterMut { inner: match self_.try_into_slice() { Ok(x) => ElementsRepr::Slice(x.iter_mut()), - Err(self_) => ElementsRepr::Counted(self_.into_elements_base()), + Err(self_) => ElementsRepr::Counted(self_.into_elements_base_preserve_order()), }, } } @@ -397,9 +395,9 @@ pub(crate) struct ElementsBaseMut<'a, A, D> { } impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> { - pub fn new(v: ArrayViewMut<'a, A, D>) -> Self { + pub fn new(v: ArrayViewMut<'a, A, D>) -> Self { ElementsBaseMut { - inner: v.into_base_iter(), + inner: v.into_base_iter::(), life: PhantomData, } } diff --git a/src/iterators/chunks.rs b/src/iterators/chunks.rs index ba0e789fb..894376f5a 100644 --- a/src/iterators/chunks.rs +++ b/src/iterators/chunks.rs @@ -79,7 +79,7 @@ where type IntoIter = ExactChunksIter<'a, A, D>; fn into_iter(self) -> Self::IntoIter { ExactChunksIter { - iter: self.base.into_elements_base(), + iter: self.base.into_elements_base_any_order(), chunk: self.chunk, inner_strides: self.inner_strides, } @@ -169,7 +169,7 @@ where type IntoIter = ExactChunksIterMut<'a, A, D>; fn into_iter(self) -> Self::IntoIter { ExactChunksIterMut { - iter: self.base.into_elements_base(), + iter: self.base.into_elements_base_any_order(), chunk: self.chunk, inner_strides: self.inner_strides, } diff --git a/src/iterators/lanes.rs b/src/iterators/lanes.rs index 8cc212e8d..74fdcd4f9 100644 --- a/src/iterators/lanes.rs +++ b/src/iterators/lanes.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use crate::imp_prelude::*; use crate::{Layout, NdProducer}; use crate::iterators::Baseiter; +use crate::iterators::base::NoOptimization; impl_ndproducer! { ['a, A, D: Dimension] @@ -83,7 +84,7 @@ where type IntoIter = LanesIter<'a, A, D>; fn into_iter(self) -> Self::IntoIter { LanesIter { - iter: self.base.into_base_iter(), + iter: self.base.into_base_iter::(), inner_len: self.inner_len, inner_stride: self.inner_stride, life: PhantomData, @@ -134,7 +135,7 @@ where type IntoIter = LanesIterMut<'a, A, D>; fn into_iter(self) -> Self::IntoIter { LanesIterMut { - iter: self.base.into_base_iter(), + iter: self.base.into_base_iter::(), inner_len: self.inner_len, inner_stride: self.inner_stride, life: PhantomData, diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 846b29491..9a13fde03 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -10,7 +10,7 @@ mod macros; mod axis; -mod base; +pub(crate) mod base; mod chunks; mod into_iter; pub mod iter; diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index c47bfecec..3b8f9de63 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -77,7 +77,7 @@ where type IntoIter = WindowsIter<'a, A, D>; fn into_iter(self) -> Self::IntoIter { WindowsIter { - iter: self.base.into_elements_base(), + iter: self.base.into_elements_base_preserve_order(), window: self.window, strides: self.strides, } diff --git a/src/lib.rs b/src/lib.rs index af058f518..545d60e38 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -142,8 +142,7 @@ pub use crate::slice::{ MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceInfoElem, SliceNextDim, }; -use crate::iterators::Baseiter; -use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut}; +use crate::iterators::{ElementsBase, ElementsBaseMut}; pub use crate::arraytraits::AsArray; #[cfg(feature = "std")]