Skip to content

Commit

Permalink
Move upcast broadcast helper to dimension mod
Browse files Browse the repository at this point in the history
This will allow upcast to be reused in other functions, such as the
upcoming ArrayView::broadcast_ref.
  • Loading branch information
roblabla committed Oct 14, 2022
1 parent 0740695 commit e797a77
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 51 deletions.
51 changes: 51 additions & 0 deletions src/dimension/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,57 @@ where
Ok(out)
}

/// Return new stride when trying to grow `from` into shape `to`
///
/// Broadcasting works by returning a "fake stride" where elements
/// to repeat are in axes with 0 stride, so that several indexes point
/// to the same element.
///
/// **Note:** Cannot be used for mutable iterators, since repeating
/// elements would create aliasing pointers.
pub(crate) fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> {
// Make sure the product of non-zero axis lengths does not exceed
// `isize::MAX`. This is the only safety check we need to perform
// because all the other constraints of `ArrayBase` are guaranteed
// to be met since we're starting from a valid `ArrayBase`.
let _ = size_of_shape_checked(to).ok()?;

let mut new_stride = to.clone();
// begin at the back (the least significant dimension)
// size of the axis has to either agree or `from` has to be 1
if to.ndim() < from.ndim() {
return None;
}

{
let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
for ((er, es), dr) in from
.slice()
.iter()
.rev()
.zip(stride.slice().iter().rev())
.zip(new_stride_iter.by_ref())
{
/* update strides */
if *dr == *er {
/* keep stride */
*dr = *es;
} else if *er == 1 {
/* dead dimension, zero stride */
*dr = 0
} else {
return None;
}
}

/* set remaining strides to zero */
for dr in new_stride_iter {
*dr = 0;
}
}
Some(new_stride)
}

pub trait DimMax<Other: Dimension> {
/// The resulting dimension type after broadcasting.
type Output: Dimension;
Expand Down
52 changes: 1 addition & 51 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::dimension::{
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
offset_from_low_addr_ptr_to_logical_ptr, size_of_shape_checked, stride_offset, Axes,
};
use crate::dimension::broadcast::co_broadcast;
use crate::dimension::broadcast::{co_broadcast, upcast};
use crate::dimension::reshape_dim;
use crate::error::{self, ErrorKind, ShapeError, from_kind};
use crate::math_cell::MathCell;
Expand Down Expand Up @@ -2036,56 +2036,6 @@ where
E: IntoDimension,
S: Data,
{
/// Return new stride when trying to grow `from` into shape `to`
///
/// Broadcasting works by returning a "fake stride" where elements
/// to repeat are in axes with 0 stride, so that several indexes point
/// to the same element.
///
/// **Note:** Cannot be used for mutable iterators, since repeating
/// elements would create aliasing pointers.
fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> {
// Make sure the product of non-zero axis lengths does not exceed
// `isize::MAX`. This is the only safety check we need to perform
// because all the other constraints of `ArrayBase` are guaranteed
// to be met since we're starting from a valid `ArrayBase`.
let _ = size_of_shape_checked(to).ok()?;

let mut new_stride = to.clone();
// begin at the back (the least significant dimension)
// size of the axis has to either agree or `from` has to be 1
if to.ndim() < from.ndim() {
return None;
}

{
let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
for ((er, es), dr) in from
.slice()
.iter()
.rev()
.zip(stride.slice().iter().rev())
.zip(new_stride_iter.by_ref())
{
/* update strides */
if *dr == *er {
/* keep stride */
*dr = *es;
} else if *er == 1 {
/* dead dimension, zero stride */
*dr = 0
} else {
return None;
}
}

/* set remaining strides to zero */
for dr in new_stride_iter {
*dr = 0;
}
}
Some(new_stride)
}
let dim = dim.into_dimension();

// Note: zero strides are safe precisely because we return an read-only view
Expand Down

0 comments on commit e797a77

Please sign in to comment.