Skip to content

Commit

Permalink
FEAT: Use Baseiter optimizations in some places where it's possible
Browse files Browse the repository at this point in the history
  • Loading branch information
bluss committed Dec 6, 2021
1 parent be87fe7 commit f31add8
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 68 deletions.
2 changes: 1 addition & 1 deletion src/array_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use alloc::vec::Vec;
use crate::imp_prelude::*;

use super::arraytraits::ARRAY_FORMAT_VERSION;
use super::Iter;
use super::iter::Iter;
use crate::IntoDimension;

/// Verifies that the version of the deserialized array matches the current
Expand Down
35 changes: 4 additions & 31 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -784,36 +784,6 @@ where
}
}

/// Move the axis which has the smallest absolute stride and a length
/// greater than one to be the last axis.
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
where
D: Dimension,
{
debug_assert_eq!(dim.ndim(), strides.ndim());
match dim.ndim() {
0 | 1 => {}
2 => {
if dim[1] <= 1
|| dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs()
{
dim.slice_mut().swap(0, 1);
strides.slice_mut().swap(0, 1);
}
}
n => {
if let Some(min_stride_axis) = (0..n)
.filter(|&ax| dim[ax] > 1)
.min_by_key(|&ax| (strides[ax] as isize).abs())
{
let last = n - 1;
dim.slice_mut().swap(last, min_stride_axis);
strides.slice_mut().swap(last, min_stride_axis);
}
}
}
}

/// Remove axes with length one, except never removing the last axis.
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
where
Expand Down Expand Up @@ -857,14 +827,17 @@ pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
where
D: Dimension,
{
debug_assert!(dim.ndim() > 1);
if dim.ndim() <= 1 {
return;
}
debug_assert_eq!(dim.ndim(), strides.ndim());
// bubble sort axes
let mut changed = true;
while changed {
changed = false;
for i in 0..dim.ndim() - 1 {
// make sure higher stride axes sort before.
debug_assert!(strides.get_stride(Axis(i)) >= 0);
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
changed = true;
dim.slice_mut().swap(i, i + 1);
Expand Down
14 changes: 5 additions & 9 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::argument_traits::AssignElem;
use crate::dimension;
use crate::dimension::IntoDimension;
use crate::dimension::{
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
abs_index, axes_of, do_slice, merge_axes,
offset_from_low_addr_ptr_to_logical_ptr, size_of_shape_checked, stride_offset, Axes,
};
use crate::dimension::broadcast::co_broadcast;
Expand Down Expand Up @@ -433,7 +433,7 @@ where
where
S: Data,
{
IndexedIter::new(self.view().into_elements_base())
IndexedIter::new(self.view().into_elements_base_keep_dims())
}

/// Return an iterator of indexes and mutable references to the elements of the array.
Expand All @@ -446,7 +446,7 @@ where
where
S: DataMut,
{
IndexedIterMut::new(self.view_mut().into_elements_base())
IndexedIterMut::new(self.view_mut().into_elements_base_keep_dims())
}

/// Return a sliced view of the array.
Expand Down Expand Up @@ -2441,9 +2441,7 @@ where
if let Some(slc) = self.as_slice_memory_order() {
slc.iter().fold(init, f)
} else {
let mut v = self.view();
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
v.into_elements_base().fold(init, f)
self.view().into_elements_base_any_order().fold(init, f)
}
}

Expand Down Expand Up @@ -2599,9 +2597,7 @@ where
match self.try_as_slice_memory_order_mut() {
Ok(slc) => slc.iter_mut().for_each(f),
Err(arr) => {
let mut v = arr.view_mut();
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
v.into_elements_base().for_each(f);
arr.view_mut().into_elements_base_any_order().for_each(f);
}
}
}
Expand Down
45 changes: 34 additions & 11 deletions src/impl_views/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ use std::mem::MaybeUninit;

use crate::imp_prelude::*;

use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut};

use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::iter::{self, AxisIter, AxisIterMut};
use crate::iter::{self, Iter, IterMut, AxisIter, AxisIterMut};
use crate::iterators::base::{Baseiter, ElementsBase, ElementsBaseMut, OrderOption, PreserveOrder,
ArbitraryOrder, NoOptimization};
use crate::math_cell::MathCell;
use crate::IndexLonger;

Expand Down Expand Up @@ -188,14 +188,25 @@ impl<'a, A, D> ArrayView<'a, A, D>
where
D: Dimension,
{
/// Create a base iter fromt the view with the given order option
#[inline]
pub(crate) fn into_base_iter<F: OrderOption>(self) -> Baseiter<A, D> {
unsafe { Baseiter::new_with_order::<F>(self.ptr.as_ptr(), self.dim, self.strides) }
}

#[inline]
pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBase<'a, A, D> {
ElementsBase::new::<NoOptimization>(self)
}

#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBase<'a, A, D> {
ElementsBase::new::<PreserveOrder>(self)
}

#[inline]
pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> {
ElementsBase::new(self)
pub(crate) fn into_elements_base_any_order(self) -> ElementsBase<'a, A, D> {
ElementsBase::new::<ArbitraryOrder>(self)
}

pub(crate) fn into_iter_(self) -> Iter<'a, A, D> {
Expand Down Expand Up @@ -227,16 +238,28 @@ where
unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) }
}

/// Create a base iter fromt the view with the given order option
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
pub(crate) fn into_base_iter<F: OrderOption>(self) -> Baseiter<A, D> {
unsafe { Baseiter::new_with_order::<F>(self.ptr.as_ptr(), self.dim, self.strides) }
}

#[inline]
pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> {
ElementsBaseMut::new(self)
pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBaseMut<'a, A, D> {
ElementsBaseMut::new::<NoOptimization>(self)
}

#[inline]
pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBaseMut<'a, A, D> {
ElementsBaseMut::new::<PreserveOrder>(self)
}

#[inline]
pub(crate) fn into_elements_base_any_order(self) -> ElementsBaseMut<'a, A, D> {
ElementsBaseMut::new::<ArbitraryOrder>(self)
}


/// Return the array’s data as a slice, if it is contiguous and in standard order.
/// Otherwise return self in the Err branch of the result.
pub(crate) fn try_into_slice(self) -> Result<&'a mut [A], Self> {
Expand Down
14 changes: 6 additions & 8 deletions src/iterators/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ impl<A, D: Dimension> Baseiter<A, D> {

/// Return the iter strides
pub(crate) fn raw_strides(&self) -> D { self.strides.clone() }
}

impl<A, D: Dimension> Baseiter<A, D> {
/// Creating a Baseiter is unsafe because shape and stride parameters need
/// to be correct to avoid performing an unsafe pointer offset while
/// iterating.
Expand Down Expand Up @@ -252,9 +250,9 @@ clone_bounds!(
);

impl<'a, A, D: Dimension> ElementsBase<'a, A, D> {
pub fn new(v: ArrayView<'a, A, D>) -> Self {
pub fn new<F: OrderOption>(v: ArrayView<'a, A, D>) -> Self {
ElementsBase {
inner: v.into_base_iter(),
inner: v.into_base_iter::<F>(),
life: PhantomData,
}
}
Expand Down Expand Up @@ -338,7 +336,7 @@ where
inner: if let Some(slc) = self_.to_slice() {
ElementsRepr::Slice(slc.iter())
} else {
ElementsRepr::Counted(self_.into_elements_base())
ElementsRepr::Counted(self_.into_elements_base_preserve_order())
},
}
}
Expand All @@ -352,7 +350,7 @@ where
IterMut {
inner: match self_.try_into_slice() {
Ok(x) => ElementsRepr::Slice(x.iter_mut()),
Err(self_) => ElementsRepr::Counted(self_.into_elements_base()),
Err(self_) => ElementsRepr::Counted(self_.into_elements_base_preserve_order()),
},
}
}
Expand Down Expand Up @@ -397,9 +395,9 @@ pub(crate) struct ElementsBaseMut<'a, A, D> {
}

impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> {
pub fn new(v: ArrayViewMut<'a, A, D>) -> Self {
pub fn new<F: OrderOption>(v: ArrayViewMut<'a, A, D>) -> Self {
ElementsBaseMut {
inner: v.into_base_iter(),
inner: v.into_base_iter::<F>(),
life: PhantomData,
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/iterators/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ where
type IntoIter = ExactChunksIter<'a, A, D>;
fn into_iter(self) -> Self::IntoIter {
ExactChunksIter {
iter: self.base.into_elements_base(),
iter: self.base.into_elements_base_any_order(),
chunk: self.chunk,
inner_strides: self.inner_strides,
}
Expand Down Expand Up @@ -169,7 +169,7 @@ where
type IntoIter = ExactChunksIterMut<'a, A, D>;
fn into_iter(self) -> Self::IntoIter {
ExactChunksIterMut {
iter: self.base.into_elements_base(),
iter: self.base.into_elements_base_any_order(),
chunk: self.chunk,
inner_strides: self.inner_strides,
}
Expand Down
5 changes: 3 additions & 2 deletions src/iterators/lanes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::marker::PhantomData;
use crate::imp_prelude::*;
use crate::{Layout, NdProducer};
use crate::iterators::Baseiter;
use crate::iterators::base::NoOptimization;

impl_ndproducer! {
['a, A, D: Dimension]
Expand Down Expand Up @@ -83,7 +84,7 @@ where
type IntoIter = LanesIter<'a, A, D>;
fn into_iter(self) -> Self::IntoIter {
LanesIter {
iter: self.base.into_base_iter(),
iter: self.base.into_base_iter::<NoOptimization>(),
inner_len: self.inner_len,
inner_stride: self.inner_stride,
life: PhantomData,
Expand Down Expand Up @@ -134,7 +135,7 @@ where
type IntoIter = LanesIterMut<'a, A, D>;
fn into_iter(self) -> Self::IntoIter {
LanesIterMut {
iter: self.base.into_base_iter(),
iter: self.base.into_base_iter::<NoOptimization>(),
inner_len: self.inner_len,
inner_stride: self.inner_stride,
life: PhantomData,
Expand Down
2 changes: 1 addition & 1 deletion src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
mod macros;

mod axis;
mod base;
pub(crate) mod base;
mod chunks;
mod into_iter;
pub mod iter;
Expand Down
2 changes: 1 addition & 1 deletion src/iterators/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ where
type IntoIter = WindowsIter<'a, A, D>;
fn into_iter(self) -> Self::IntoIter {
WindowsIter {
iter: self.base.into_elements_base(),
iter: self.base.into_elements_base_preserve_order(),
window: self.window,
strides: self.strides,
}
Expand Down
3 changes: 1 addition & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ pub use crate::slice::{
MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceInfoElem, SliceNextDim,
};

use crate::iterators::Baseiter;
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut};
use crate::iterators::{ElementsBase, ElementsBaseMut};

pub use crate::arraytraits::AsArray;
#[cfg(feature = "std")]
Expand Down

0 comments on commit f31add8

Please sign in to comment.