Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic support for axis simplification and arbitrary order in iterators #979

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
72 changes: 72 additions & 0 deletions benches/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,78 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher)
bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::<f32>());
}

#[bench]
fn iter_sum_2d_row_matrix(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64);
let v = a.view().insert_axis(Axis(1));
bench.iter(|| {
let mut s = 0;
for &elt in v.iter() {
s += elt;
}
s
});
}

#[bench]
fn iter_sum_2d_row_matrix_for_strided(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]);
let v = a.view().insert_axis(Axis(1));
bench.iter(|| {
let mut s = 0;
for &elt in v.iter() {
s += elt;
}
s
});
}

#[bench]
fn iter_sum_2d_row_matrix_sum_strided(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]);
let v = a.view().insert_axis(Axis(1));
bench.iter(|| v.iter().sum::<i32>());
}

#[bench]
fn iter_sum_2d_col_matrix(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64);
let v = a.view().insert_axis(Axis(0));
bench.iter(|| {
let mut s = 0;
for &elt in v.iter() {
s += elt;
}
s
});
}

#[bench]
fn iter_sum_2d_col_matrix_for_strided(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]);
let v = a.view().insert_axis(Axis(0));
bench.iter(|| {
let mut s = 0;
for &elt in v.iter() {
s += elt;
}
s
});
}

#[bench]
fn iter_sum_2d_col_matrix_sum_strided(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]);
let v = a.view().insert_axis(Axis(0));
bench.iter(|| v.iter().sum::<i32>());
}

#[bench]
fn iter_rev_step_by_contiguous(bench: &mut Bencher)
{
Expand Down
157 changes: 157 additions & 0 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,20 @@ pub trait DimensionExt
/// *Panics* if `axis` is out of bounds.
#[track_caller]
fn set_axis(&mut self, axis: Axis, value: Ix);

/// Get as stride
#[inline]
fn get_stride(&self, axis: Axis) -> isize
{
self.axis(axis) as isize
}

/// Set as stride
#[inline]
fn set_stride(&mut self, axis: Axis, value: isize)
{
self.set_axis(axis, value as usize)
}
}

impl<D> DimensionExt for D
Expand Down Expand Up @@ -745,6 +759,32 @@ where D: Dimension
}
}

/// Attempt to merge axes if possible, starting from the back
///
/// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
/// to merge all axes one by one into Axis(3); when/if this fails,
/// it attempts to merge the rest of the axes together into the next
/// axis in line, for example a result could be:
///
/// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
/// mean axes were merged.
pub(crate) fn merge_axes_from_the_back<D>(dim: &mut D, strides: &mut D)
where D: Dimension
{
debug_assert_eq!(dim.ndim(), strides.ndim());
match dim.ndim() {
0 | 1 => {}
n => {
let mut last = n - 1;
for i in (0..last).rev() {
if !merge_axes(dim, strides, Axis(i), Axis(last)) {
last = i;
jturner314 marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}
}

/// Move the axis which has the smallest absolute stride and a length
/// greater than one to be the last axis.
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
Expand All @@ -771,6 +811,67 @@ where D: Dimension
}
}

/// Remove axes with length one, except never removing the last axis.
///
/// This only has effect on dynamic dimensions.
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
where D: Dimension
{
if let Some(_) = D::NDIM {
return;
}
debug_assert_eq!(dim.ndim(), strides.ndim());

// Count axes with dim == 1; we keep axes with d == 0 or d > 1
let mut ndim_new = 0;
for &d in dim.slice() {
if d != 1 {
ndim_new += 1;
}
}
ndim_new = Ord::max(1, ndim_new);
let mut new_dim = D::zeros(ndim_new);
let mut new_strides = D::zeros(ndim_new);
let mut i = 0;
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
if d != 1 {
new_dim[i] = d;
new_strides[i] = s;
i += 1;
}
}
if i == 0 {
new_dim[i] = 1;
new_strides[i] = 1;
}
*dim = new_dim;
*strides = new_strides;
}

/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
/// stride
///
/// The axes are sorted according to the .abs() of their stride.
pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
where D: Dimension
{
debug_assert!(dim.ndim() > 1);
debug_assert_eq!(dim.ndim(), strides.ndim());
// bubble sort axes
let mut changed = true;
while changed {
changed = false;
for i in 0..dim.ndim() - 1 {
// make sure higher stride axes sort before.
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
changed = true;
dim.slice_mut().swap(i, i + 1);
strides.slice_mut().swap(i, i + 1);
}
}
}
}

#[cfg(test)]
mod test
{
Expand All @@ -780,9 +881,11 @@ mod test
can_index_slice_not_custom,
extended_gcd,
max_abs_offset_check_overflow,
merge_axes_from_the_back,
slice_min_max,
slices_intersect,
solve_linear_diophantine_eq,
squeeze,
IntoDimension,
};
use crate::error::{from_kind, ErrorKind};
Expand Down Expand Up @@ -1132,4 +1235,58 @@ mod test
s![.., 3..;6, NewAxis]
));
}

#[test]
#[cfg(feature = "std")]
fn test_squeeze()
{
let dyndim = Dim::<&[usize]>;

let mut d = dyndim(&[1, 2, 1, 1, 3, 1]);
let mut s = dyndim(&[!0, !0, !0, 9, 10, !0]);
let dans = dyndim(&[2, 3]);
let sans = dyndim(&[!0, 10]);
squeeze(&mut d, &mut s);
assert_eq!(d, dans);
assert_eq!(s, sans);

let mut d = dyndim(&[1, 1]);
let mut s = dyndim(&[3, 4]);
let dans = dyndim(&[1]);
let sans = dyndim(&[1]);
squeeze(&mut d, &mut s);
assert_eq!(d, dans);
assert_eq!(s, sans);

let mut d = dyndim(&[0, 1, 3, 4]);
let mut s = dyndim(&[2, 3, 4, 5]);
let dans = dyndim(&[0, 3, 4]);
let sans = dyndim(&[2, 4, 5]);
squeeze(&mut d, &mut s);
assert_eq!(d, dans);
assert_eq!(s, sans);
}

#[test]
fn test_merge_axes_from_the_back()
{
let dyndim = Dim::<&[usize]>;

let mut d = Dim([3, 4, 5]);
let mut s = Dim([20, 5, 1]);
merge_axes_from_the_back(&mut d, &mut s);
assert_eq!(d, Dim([1, 1, 60]));
assert_eq!(s, Dim([20, 5, 1]));

let mut d = Dim([3, 4, 5, 2]);
let mut s = Dim([80, 20, 2, 1]);
merge_axes_from_the_back(&mut d, &mut s);
assert_eq!(d, Dim([1, 12, 1, 10]));
assert_eq!(s, Dim([80, 20, 2, 1]));
let mut d = d.into_dyn();
let mut s = s.into_dyn();
squeeze(&mut d, &mut s);
assert_eq!(d, dyndim(&[12, 10]));
assert_eq!(s, dyndim(&[20, 1]));
}
}
Loading