Skip to content

Commit

Permalink
shape: Add optional order argument for into_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
bluss committed Jul 27, 2023
1 parent 014290c commit dad6bbd
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 13 deletions.
59 changes: 46 additions & 13 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1885,38 +1885,71 @@ where
}

/// Transform the array into `shape`; any shape with the same number of
/// elements is accepted, but the source array or view must be in standard
/// or column-major (Fortran) layout.
/// elements is accepted, but the source array must be in contiguous row-major (C) or
/// column-major (F) layout.
///
/// **Note** that `.into_shape()` "moves" elements differently depending on if the input array
/// is C-contig or F-contig, it follows the index order that corresponds to the memory
/// order. If this is not wanted, use `.to_shape()`.
///
/// If a memory ordering is specified (optional) in the shape argument, the operation
/// will only succeed if the input has this memory order.
///
/// **Errors** if the shapes don't have the same number of elements.<br>
/// **Errors** if the input array is not c- or f-contiguous.
/// **Errors** if the input array is not c- or f-contiguous.<br>
/// **Errors** if a memory ordering is requested that is not compatible with the array.<br>
///
/// If shape is not given: use memory layout of incoming array. Row major arrays are
/// reshaped using row major index ordering, column major arrays with column major index
/// ordering.
///
/// ```
/// use ndarray::{aview1, aview2};
/// use ndarray::Order;
///
/// assert!(
/// aview1(&[1., 2., 3., 4.]).into_shape((2, 2)).unwrap()
/// == aview2(&[[1., 2.],
/// [3., 4.]])
/// );
///
/// assert!(
/// aview1(&[1., 2., 3., 4.]).into_shape(((2, 2), Order::ColumnMajor)).unwrap()
/// == aview2(&[[1., 3.],
/// [2., 4.]])
/// );
/// ```
pub fn into_shape<E>(self, shape: E) -> Result<ArrayBase<S, E::Dim>, ShapeError>
where
E: IntoDimension,
E: ShapeArg,
{
let (shape, order) = shape.into_shape_and_order();
self.into_shape_order(shape, order)
}

fn into_shape_order<E>(self, shape: E, order: Option<Order>) -> Result<ArrayBase<S, E>, ShapeError>
where
E: Dimension,
{
let shape = shape.into_dimension();
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
return Err(error::incompatible_shapes(&self.dim, &shape));
}
// Check if contiguous, if not => copy all, else just adapt strides

// Check if contiguous, then we can change shape
let require_order = order.is_some();
unsafe {
// safe because arrays are contiguous and len is unchanged
if self.is_standard_layout() {
Ok(self.with_strides_dim(shape.default_strides(), shape))
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
Ok(self.with_strides_dim(shape.fortran_strides(), shape))
} else {
Err(error::from_kind(error::ErrorKind::IncompatibleLayout))
match order {
None | Some(Order::RowMajor) if self.is_standard_layout() => {
Ok(self.with_strides_dim(shape.default_strides(), shape))
}
None | Some(Order::ColumnMajor) if (require_order || self.ndim() > 1) &&
self.raw_view().reversed_axes().is_standard_layout() =>
{
Ok(self.with_strides_dim(shape.fortran_strides(), shape))
}
_otherwise => Err(error::from_kind(error::ErrorKind::IncompatibleLayout))
}
}
}
Expand All @@ -1932,7 +1965,7 @@ where
self.into_shape_clone_order(shape, order)
}

pub fn into_shape_clone_order<E>(self, shape: E, order: Order)
fn into_shape_clone_order<E>(self, shape: E, order: Order)
-> Result<ArrayBase<S, E>, ShapeError>
where
S: DataOwned,
Expand Down Expand Up @@ -2004,7 +2037,7 @@ where
A: Clone,
E: IntoDimension,
{
return self.clone().into_shape_clone(shape).unwrap();
//return self.clone().into_shape_clone(shape).unwrap();
let shape = shape.into_dimension();
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
panic!(
Expand Down
38 changes: 38 additions & 0 deletions tests/reshape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,41 @@ fn to_shape_broadcast() {
}
}
}


#[test]
fn into_shape_easy() {
// 1D -> C -> C
let data = [1, 2, 3, 4, 5, 6, 7, 8];
let v = aview1(&data);
let u = v.into_shape(((3, 3), Order::RowMajor));
assert!(u.is_err());

let u = v.into_shape(((2, 2, 2), Order::C));
assert!(u.is_ok());

let u = u.unwrap();
assert_eq!(u.shape(), &[2, 2, 2]);
assert_eq!(u, array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);

let s = u.into_shape((4, 2)).unwrap();
assert_eq!(s.shape(), &[4, 2]);
assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]]));

// 1D -> F -> F
let data = [1, 2, 3, 4, 5, 6, 7, 8];
let v = aview1(&data);
let u = v.into_shape(((3, 3), Order::ColumnMajor));
assert!(u.is_err());

let u = v.into_shape(((2, 2, 2), Order::ColumnMajor));
assert!(u.is_ok());

let u = u.unwrap();
assert_eq!(u.shape(), &[2, 2, 2]);
assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]);

let s = u.into_shape(((4, 2), Order::ColumnMajor)).unwrap();
assert_eq!(s.shape(), &[4, 2]);
assert_eq!(s, array![[1, 5], [2, 6], [3, 7], [4, 8]]);
}

0 comments on commit dad6bbd

Please sign in to comment.