diff --git a/benches/iter.rs b/benches/iter.rs index 77f511745..738b72ce1 100644 --- a/benches/iter.rs +++ b/benches/iter.rs @@ -87,6 +87,78 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::()); } +#[bench] +fn iter_sum_2d_row_matrix(bench: &mut Bencher) +{ + let a = Array::from_iter(0i32..64 * 64); + let v = a.view().insert_axis(Axis(1)); + bench.iter(|| { + let mut s = 0; + for &elt in v.iter() { + s += elt; + } + s + }); +} + +#[bench] +fn iter_sum_2d_row_matrix_for_strided(bench: &mut Bencher) +{ + let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]); + let v = a.view().insert_axis(Axis(1)); + bench.iter(|| { + let mut s = 0; + for &elt in v.iter() { + s += elt; + } + s + }); +} + +#[bench] +fn iter_sum_2d_row_matrix_sum_strided(bench: &mut Bencher) +{ + let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]); + let v = a.view().insert_axis(Axis(1)); + bench.iter(|| v.iter().sum::()); +} + +#[bench] +fn iter_sum_2d_col_matrix(bench: &mut Bencher) +{ + let a = Array::from_iter(0i32..64 * 64); + let v = a.view().insert_axis(Axis(0)); + bench.iter(|| { + let mut s = 0; + for &elt in v.iter() { + s += elt; + } + s + }); +} + +#[bench] +fn iter_sum_2d_col_matrix_for_strided(bench: &mut Bencher) +{ + let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]); + let v = a.view().insert_axis(Axis(0)); + bench.iter(|| { + let mut s = 0; + for &elt in v.iter() { + s += elt; + } + s + }); +} + +#[bench] +fn iter_sum_2d_col_matrix_sum_strided(bench: &mut Bencher) +{ + let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]); + let v = a.view().insert_axis(Axis(0)); + bench.iter(|| v.iter().sum::()); +} + #[bench] fn iter_rev_step_by_contiguous(bench: &mut Bencher) { diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index e1563613e..888feaf66 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -313,6 +313,20 @@ pub trait DimensionExt /// *Panics* if `axis` is out of bounds. 
#[track_caller] fn set_axis(&mut self, axis: Axis, value: Ix); + + /// Get the value at `axis`, reinterpreted as a signed stride + #[inline] + fn get_stride(&self, axis: Axis) -> isize + { + self.axis(axis) as isize + } + + /// Set the value at `axis` from a signed stride + #[inline] + fn set_stride(&mut self, axis: Axis, value: isize) + { + self.set_axis(axis, value as usize) + } } impl DimensionExt for D @@ -745,6 +759,32 @@ where D: Dimension } } +/// Attempt to merge axes if possible, starting from the back +/// +/// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts +/// to merge all axes one by one into Axis(3); when/if this fails, +/// it attempts to merge the rest of the axes together into the next +/// axis in line, for example a result could be: +/// +/// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would +/// mean axes were merged. +pub(crate) fn merge_axes_from_the_back(dim: &mut D, strides: &mut D) +where D: Dimension +{ + debug_assert_eq!(dim.ndim(), strides.ndim()); + match dim.ndim() { + 0 | 1 => {} + n => { + let mut last = n - 1; + for i in (0..last).rev() { + if !merge_axes(dim, strides, Axis(i), Axis(last)) { + last = i; + } + } + } + } +} + /// Move the axis which has the smallest absolute stride and a length /// greater than one to be the last axis. pub fn move_min_stride_axis_to_last(dim: &mut D, strides: &mut D) @@ -771,6 +811,67 @@ where D: Dimension } } +/// Remove all axes with length one, but always keep at least one axis. +/// +/// This only has effect on dynamic dimensions. 
+pub(crate) fn squeeze(dim: &mut D, strides: &mut D) +where D: Dimension +{ + if let Some(_) = D::NDIM { + return; + } + debug_assert_eq!(dim.ndim(), strides.ndim()); + + // Count the axes we keep, i.e. those with d != 1 (d == 0 or d > 1) + let mut ndim_new = 0; + for &d in dim.slice() { + if d != 1 { + ndim_new += 1; + } + } + ndim_new = Ord::max(1, ndim_new); + let mut new_dim = D::zeros(ndim_new); + let mut new_strides = D::zeros(ndim_new); + let mut i = 0; + for (&d, &s) in izip!(dim.slice(), strides.slice()) { + if d != 1 { + new_dim[i] = d; + new_strides[i] = s; + i += 1; + } + } + if i == 0 { + new_dim[i] = 1; + new_strides[i] = 1; + } + *dim = new_dim; + *strides = new_strides; +} + +/// Sort axes to standard/row major order, i.e. Axis(0) has biggest stride and Axis(n - 1) least +/// stride +/// +/// The axes are sorted according to the .abs() of their stride. Requires `dim.ndim() > 1` (debug-asserted). +pub(crate) fn sort_axes_to_standard(dim: &mut D, strides: &mut D) +where D: Dimension +{ + debug_assert!(dim.ndim() > 1); + debug_assert_eq!(dim.ndim(), strides.ndim()); + // bubble sort axes + let mut changed = true; + while changed { + changed = false; + for i in 0..dim.ndim() - 1 { + // make sure higher stride axes sort before. 
+ if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() { + changed = true; + dim.slice_mut().swap(i, i + 1); + strides.slice_mut().swap(i, i + 1); + } + } + } +} + #[cfg(test)] mod test { @@ -780,9 +881,11 @@ mod test can_index_slice_not_custom, extended_gcd, max_abs_offset_check_overflow, + merge_axes_from_the_back, slice_min_max, slices_intersect, solve_linear_diophantine_eq, + squeeze, IntoDimension, }; use crate::error::{from_kind, ErrorKind}; @@ -1132,4 +1235,58 @@ mod test s![.., 3..;6, NewAxis] )); } + + #[test] + #[cfg(feature = "std")] + fn test_squeeze() + { + let dyndim = Dim::<&[usize]>; + + let mut d = dyndim(&[1, 2, 1, 1, 3, 1]); + let mut s = dyndim(&[!0, !0, !0, 9, 10, !0]); + let dans = dyndim(&[2, 3]); + let sans = dyndim(&[!0, 10]); + squeeze(&mut d, &mut s); + assert_eq!(d, dans); + assert_eq!(s, sans); + + let mut d = dyndim(&[1, 1]); + let mut s = dyndim(&[3, 4]); + let dans = dyndim(&[1]); + let sans = dyndim(&[1]); + squeeze(&mut d, &mut s); + assert_eq!(d, dans); + assert_eq!(s, sans); + + let mut d = dyndim(&[0, 1, 3, 4]); + let mut s = dyndim(&[2, 3, 4, 5]); + let dans = dyndim(&[0, 3, 4]); + let sans = dyndim(&[2, 4, 5]); + squeeze(&mut d, &mut s); + assert_eq!(d, dans); + assert_eq!(s, sans); + } + + #[test] + fn test_merge_axes_from_the_back() + { + let dyndim = Dim::<&[usize]>; + + let mut d = Dim([3, 4, 5]); + let mut s = Dim([20, 5, 1]); + merge_axes_from_the_back(&mut d, &mut s); + assert_eq!(d, Dim([1, 1, 60])); + assert_eq!(s, Dim([20, 5, 1])); + + let mut d = Dim([3, 4, 5, 2]); + let mut s = Dim([80, 20, 2, 1]); + merge_axes_from_the_back(&mut d, &mut s); + assert_eq!(d, Dim([1, 12, 1, 10])); + assert_eq!(s, Dim([80, 20, 2, 1])); + let mut d = d.into_dyn(); + let mut s = s.into_dyn(); + squeeze(&mut d, &mut s); + assert_eq!(d, dyndim(&[12, 10])); + assert_eq!(s, dyndim(&[20, 1])); + } } diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 4851b2827..1f3c3472c 100644 --- a/src/iterators/mod.rs 
+++ b/src/iterators/mod.rs @@ -19,24 +19,61 @@ use alloc::vec::Vec; use std::iter::FromIterator; use std::marker::PhantomData; use std::ptr; +use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; -use crate::Ix1; +use crate::imp_prelude::*; -use super::{ArrayBase, ArrayView, ArrayViewMut, Axis, Data, NdProducer, RemoveAxis}; -use super::{Dimension, Ix, Ixs}; +use super::NdProducer; pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut}; pub use self::into_iter::IntoIter; pub use self::lanes::{Lanes, LanesMut}; pub use self::windows::Windows; -use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; +use crate::dimension; + +/// No traversal optimizations that would change element order or axis dimensions are permitted. +/// +/// This option is suitable for example for the indexed iterator. +pub(crate) enum NoOptimization {} + +/// Preserve element iteration order, but modify dimensions if profitable; for example we can +/// change from shape [10, 1] to [1, 10], because that axis has len == 1, without consequence here. +/// +/// This option is suitable for example for the default .iter() iterator. +pub(crate) enum PreserveOrder {} + +/// Allow use of arbitrary element iteration order +/// +/// This option is suitable for example for an arbitrary order iterator. +pub(crate) enum ArbitraryOrder {} + +pub(crate) trait OrderOption +{ + const ALLOW_REMOVE_REDUNDANT_AXES: bool = false; + const ALLOW_ARBITRARY_ORDER: bool = false; +} + +impl OrderOption for NoOptimization {} + +impl OrderOption for PreserveOrder +{ + const ALLOW_REMOVE_REDUNDANT_AXES: bool = true; +} + +impl OrderOption for ArbitraryOrder +{ + const ALLOW_REMOVE_REDUNDANT_AXES: bool = true; + const ALLOW_ARBITRARY_ORDER: bool = true; +} /// Base for iterators over all axes. /// /// Iterator element type is `*mut A`. 
+/// +/// `F` is for layout/iteration order flags #[derive(Debug)] -pub struct Baseiter +pub(crate) struct Baseiter { ptr: *mut A, dim: D, @@ -50,13 +87,46 @@ impl Baseiter /// to be correct to avoid performing an unsafe pointer offset while /// iterating. #[inline] - pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter + pub unsafe fn new(ptr: *mut A, dim: D, strides: D) -> Baseiter { + Self::new_with_order::(ptr, dim, strides) + } +} + +impl Baseiter +{ + /// Creating a Baseiter is unsafe because shape and stride parameters need + /// to be correct to avoid performing an unsafe pointer offset while + /// iterating. + #[inline] + pub unsafe fn new_with_order(mut ptr: *mut A, mut dim: D, mut strides: D) -> Baseiter + { + debug_assert_eq!(dim.ndim(), strides.ndim()); + if Flags::ALLOW_ARBITRARY_ORDER { + // iterate in memory order; merge axes if possible + // make all axes positive and put the pointer back to the first element in memory + let offset = dimension::offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides); + ptr = ptr.sub(offset); + for i in 0..strides.ndim() { + let s = strides.get_stride(Axis(i)); + if s < 0 { + strides.set_stride(Axis(i), -s); + } + } + dimension::sort_axes_to_standard(&mut dim, &mut strides); + } + + if Flags::ALLOW_REMOVE_REDUNDANT_AXES { + // preserve element order but shift dimensions + dimension::merge_axes_from_the_back(&mut dim, &mut strides); + dimension::squeeze(&mut dim, &mut strides); + } + Baseiter { ptr, - index: len.first_index(), - dim: len, - strides: stride, + index: dim.first_index(), + dim, + strides, } } } @@ -341,7 +411,7 @@ pub struct Iter<'a, A, D> /// Counted read only iterator #[derive(Debug)] -pub struct ElementsBase<'a, A, D> +pub(crate) struct ElementsBase<'a, A, D> { inner: Baseiter, life: PhantomData<&'a A>, @@ -362,7 +432,7 @@ pub struct IterMut<'a, A, D> /// /// Iterator element type is `&'a mut A`. 
#[derive(Debug)] -pub struct ElementsBaseMut<'a, A, D> +pub(crate) struct ElementsBaseMut<'a, A, D> { inner: Baseiter, life: PhantomData<&'a mut A>, @@ -829,7 +899,7 @@ impl<'a, A> DoubleEndedIterator for LanesIterMut<'a, A, Ix1> } #[derive(Debug)] -pub struct AxisIterCore +struct AxisIterCore { /// Index along the axis of the value of `.next()`, relative to the start /// of the axis. @@ -1530,10 +1600,11 @@ send_sync_read_write!(ElementsBaseMut); /// (Trait used internally) An iterator that we trust /// to deliver exactly as many items as it said it would. /// +/// # Safety +/// /// The iterator must produce exactly the number of elements it reported or /// diverge before reaching the end. -#[allow(clippy::missing_safety_doc)] // not nameable downstream -pub unsafe trait TrustedIterator {} +pub(crate) unsafe trait TrustedIterator {} use crate::indexes::IndicesIterF; use crate::iter::IndicesIter; @@ -1558,14 +1629,14 @@ unsafe impl TrustedIterator for IndicesIterF where D: Dimension {} unsafe impl TrustedIterator for IntoIter where D: Dimension {} /// Like Iterator::collect, but only for trusted length iterators -pub fn to_vec(iter: I) -> Vec +pub(crate) fn to_vec(iter: I) -> Vec where I: TrustedIterator + ExactSizeIterator { to_vec_mapped(iter, |x| x) } /// Like Iterator::collect, but only for trusted length iterators -pub fn to_vec_mapped(iter: I, mut f: F) -> Vec +pub(crate) fn to_vec_mapped(iter: I, mut f: F) -> Vec where I: TrustedIterator + ExactSizeIterator, F: FnMut(I::Item) -> B, @@ -1586,3 +1657,158 @@ where debug_assert_eq!(size, result.len()); result } + +#[cfg(test)] +#[cfg(feature = "std")] +mod tests +{ + use super::Baseiter; + use super::{ArbitraryOrder, NoOptimization, PreserveOrder}; + use crate::prelude::*; + use itertools::assert_equal; + use itertools::Itertools; + + // 3-d axis swaps + fn swaps() -> impl Iterator> + { + vec![ + vec![], + vec![(0, 1)], + vec![(0, 2)], + vec![(1, 2)], + vec![(0, 1), (1, 2)], + vec![(0, 1), (0, 2)], + ] + 
.into_iter() + } + + // 3-d axis inverts + fn inverts() -> impl Iterator> + { + vec![ + vec![], + vec![Axis(0)], + vec![Axis(1)], + vec![Axis(2)], + vec![Axis(0), Axis(1)], + vec![Axis(0), Axis(2)], + vec![Axis(1), Axis(2)], + vec![Axis(0), Axis(1), Axis(2)], + ] + .into_iter() + } + + #[test] + fn test_arbitrary_order() + { + for swap in swaps() { + for invert in inverts() { + for &slice in &[false, true] { + // pattern is 0, 1; 4, 5; 8, 9; etc.. + let mut a = Array::from_iter(0..24) + .into_shape_with_order((3, 4, 2)) + .unwrap(); + if slice { + a.slice_collapse(s![.., ..;2, ..]); + } + for &(i, j) in &swap { + a.swap_axes(i, j); + } + for &i in &invert { + a.invert_axis(i); + } + unsafe { + // Should have in-memory order for arbitrary order + let iter = Baseiter::new_with_order::(a.as_mut_ptr(), a.dim, a.strides); + if !slice { + assert_equal(iter.map(|ptr| *ptr), 0..a.len()); + } else { + assert_eq!(iter.map(|ptr| *ptr).collect_vec(), + (0..a.len() * 2).filter(|&x| (x / 2) % 2 == 0).collect_vec()); + } + } + } + } + } + } + + #[test] + fn test_logical_order() + { + for swap in swaps() { + for invert in inverts() { + for &slice in &[false, true] { + let mut a = Array::from_iter(0..24) + .into_shape_with_order((3, 4, 2)) + .unwrap(); + for &(i, j) in &swap { + a.swap_axes(i, j); + } + for &i in &invert { + a.invert_axis(i); + } + if slice { + a.slice_collapse(s![.., ..;2, ..]); + } + + unsafe { + let mut iter = Baseiter::new_with_order::(a.as_mut_ptr(), a.dim, a.strides); + let mut index = Dim([0, 0, 0]); + let mut elts = 0; + while let Some(elt) = iter.next() { + assert_eq!(*elt, a[index]); + if let Some(index_) = a.raw_dim().next_for(index) { + index = index_; + } + elts += 1; + } + assert_eq!(elts, a.len()); + } + } + } + } + } + + #[test] + fn test_preserve_order() + { + for swap in swaps() { + for invert in inverts() { + for &slice in &[false, true] { + let mut a = Array::from_iter(0..20) + .into_shape_with_order((2, 10, 1)) + .unwrap(); + for &(i, j) in 
&swap { + a.swap_axes(i, j); + } + for &i in &invert { + a.invert_axis(i); + } + if slice { + a.slice_collapse(s![.., ..;2, ..]); + } + + unsafe { + let mut iter = Baseiter::new_with_order::(a.as_mut_ptr(), a.dim, a.strides); + + // check that axes have been merged (when it's easy to check) + if a.shape() == &[2, 10, 1] && invert.is_empty() { + assert_eq!(iter.dim, Dim([1, 1, 20])); + } + + let mut index = Dim([0, 0, 0]); + let mut elts = 0; + while let Some(elt) = iter.next() { + assert_eq!(*elt, a[index]); + if let Some(index_) = a.raw_dim().next_for(index) { + index = index_; + } + elts += 1; + } + assert_eq!(elts, a.len()); + } + } + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 37af0adfe..fca7e443c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ #![doc(html_logo_url = "https://rust-ndarray.github.io/images/rust-ndarray_logo.svg")] #![allow( unstable_name_collisions, // our `PointerExt` collides with upcoming inherent methods on `NonNull` + clippy::redundant_pattern_matching, // if let is sometimes good style clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal,