Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shrink_to_fit to Array #1141

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 31 additions & 24 deletions src/data_repr.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::extension::nonnull;
use alloc::borrow::ToOwned;
use alloc::slice;
use alloc::vec::Vec;
use std::mem;
use std::mem::ManuallyDrop;
use std::ptr::NonNull;
use alloc::slice;
use alloc::borrow::ToOwned;
use alloc::vec::Vec;
use crate::extension::nonnull;

use rawpointer::PointerExt;

Expand All @@ -30,24 +30,20 @@ impl<A> OwnedRepr<A> {
let len = v.len();
let capacity = v.capacity();
let ptr = nonnull::nonnull_from_vec_data(&mut v);
Self {
ptr,
len,
capacity,
}
Self { ptr, len, capacity }
}

pub(crate) fn into_vec(self) -> Vec<A> {
ManuallyDrop::new(self).take_as_vec()
}

pub(crate) fn as_slice(&self) -> &[A] {
unsafe {
slice::from_raw_parts(self.ptr.as_ptr(), self.len)
}
unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
}

pub(crate) fn len(&self) -> usize { self.len }
pub(crate) fn len(&self) -> usize {
self.len
}

pub(crate) fn as_ptr(&self) -> *const A {
self.ptr.as_ptr()
Expand All @@ -63,13 +59,11 @@ impl<A> OwnedRepr<A> {

/// Return end pointer
pub(crate) fn as_end_nonnull(&self) -> NonNull<A> {
unsafe {
self.ptr.add(self.len)
}
unsafe { self.ptr.add(self.len) }
}

/// Reserve `additional` elements; return the new pointer
///
///
/// ## Safety
///
/// Note that existing pointers into the data are invalidated
Expand All @@ -82,6 +76,21 @@ impl<A> OwnedRepr<A> {
self.as_nonnull_mut()
}

/// Shrink the length of the data to `len` and the capacity of the
/// allocation as much as possible.
///
/// Elements past `len` are dropped. If `len` is greater than the current
/// length, this is a no-op.
pub(crate) fn shrink_to_fit(&mut self, len: usize) {
    if len <= self.len {
        // Round-trip through `Vec` so that the tail elements are dropped
        // properly and `ptr`/`len`/`capacity` all stay consistent with the
        // actual allocation. `Vec::shrink_to_fit` is allowed to keep excess
        // capacity, so we must not overwrite `self.capacity` with `len`
        // manually: deallocating later with a smaller-than-real capacity
        // would be undefined behavior.
        self.modify_as_vec(|mut v| {
            v.truncate(len);
            v.shrink_to_fit();
            v
        });
    }
}

/// Set the valid length of the data
///
/// ## Safety
Expand Down Expand Up @@ -126,14 +135,13 @@ impl<A> OwnedRepr<A> {
let len = self.len;
self.len = 0;
self.capacity = 0;
unsafe {
Vec::from_raw_parts(self.ptr.as_ptr(), len, capacity)
}
unsafe { Vec::from_raw_parts(self.ptr.as_ptr(), len, capacity) }
}
}

impl<A> Clone for OwnedRepr<A>
where A: Clone
where
A: Clone,
{
fn clone(&self) -> Self {
Self::from(self.as_slice().to_owned())
Expand Down Expand Up @@ -174,6 +182,5 @@ impl<A> Drop for OwnedRepr<A> {
}
}

unsafe impl<A> Sync for OwnedRepr<A> where A: Sync { }
unsafe impl<A> Send for OwnedRepr<A> where A: Send { }

unsafe impl<A> Sync for OwnedRepr<A> where A: Sync {}
unsafe impl<A> Send for OwnedRepr<A> where A: Send {}
130 changes: 107 additions & 23 deletions src/impl_owned_array.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

use alloc::vec::Vec;
use std::mem;
use std::mem::MaybeUninit;
Expand All @@ -11,6 +10,7 @@ use crate::dimension;
use crate::error::{ErrorKind, ShapeError};
use crate::iterators::Baseiter;
use crate::low_level_util::AbortIfPanic;
use crate::ArrayViewMut;
use crate::OwnedRepr;
use crate::Zip;

Expand Down Expand Up @@ -166,7 +166,8 @@ impl<A> Array<A, Ix2> {
}

impl<A, D> Array<A, D>
where D: Dimension
where
D: Dimension,
{
/// Move all elements from self into `new_array`, which must be of the same shape but
/// can have a different memory layout. The destination is overwritten completely.
Expand Down Expand Up @@ -198,9 +199,7 @@ impl<A, D> Array<A, D>
} else {
// If `A` doesn't need drop, we can overwrite the destination.
// Safe because: move_into_uninit only writes initialized values
unsafe {
self.move_into_uninit(new_array.into_maybe_uninit())
}
unsafe { self.move_into_uninit(new_array.into_maybe_uninit()) }
}
}

Expand All @@ -209,7 +208,8 @@ impl<A, D> Array<A, D>
// Afterwards, `self` drops full of initialized values and dropping works as usual.
// This avoids moving out of owned values in `self` while at the same time managing
// the dropping if the values being overwritten in `new_array`.
Zip::from(&mut self).and(new_array)
Zip::from(&mut self)
.and(new_array)
.for_each(|src, dst| mem::swap(src, dst));
}

Expand Down Expand Up @@ -401,16 +401,87 @@ impl<A, D> Array<A, D>
/// [0., 0., 0., 0.],
/// [1., 1., 1., 1.]]);
/// ```
pub fn push(&mut self, axis: Axis, array: ArrayView<A, D::Smaller>)
-> Result<(), ShapeError>
pub fn push(&mut self, axis: Axis, array: ArrayView<A, D::Smaller>) -> Result<(), ShapeError>
where
A: Clone,
D: RemoveAxis,
{
// same-dimensionality conversion
self.append(axis, array.insert_axis(axis).into_dimensionality::<D>().unwrap())
self.append(
axis,
array.insert_axis(axis).into_dimensionality::<D>().unwrap(),
)
}

/// Return the offset (in elements) of the logical pointer `self.ptr`
/// from the start of the backing allocation `self.data`, i.e.
/// `self.ptr - self.data.ptr` in `offset_from` terms.
///
/// # Safety
///
/// `self.ptr` must point into (or one past) the allocation owned by
/// `self.data`, otherwise `offset_from` is undefined behavior. This holds
/// for any validly constructed array.
unsafe fn offset_from_data_ptr_to_logical_ptr(&self) -> isize {
// For zero-sized types every pointer coincides; report offset 0 instead
// of calling `offset_from` (which is UB for ZSTs).
if std::mem::size_of::<A>() != 0 {
self.as_ptr().offset_from(self.data.as_ptr())
} else {
0
}
}

/// Shrinks the capacity of the array as much as possible.
///
/// The elements reachable through `self` are compacted to the front of the
/// allocation — preserving the relative order of the axis strides — and the
/// backing buffer is then shrunk to hold exactly `self.len()` elements.
///
/// ```
/// use ndarray::array;
/// use ndarray::s;
///
/// let a = array![[0, 1, 2], [3, 4, 5], [6, 7, 8]];
/// let mut a = a.slice_move(s![.., 0..2]);
/// let b = a.clone();
/// a.shrink_to_fit();
/// assert_eq!(a, b);
/// ```
pub fn shrink_to_fit(&mut self)
where
    A: Copy,
{
    let dim = self.dim.clone();
    let strides = self.strides.clone();
    let mut shrinked_stride = D::zeros(dim.ndim());

    // Compute the strides of the compacted layout. The relative order of
    // the stride magnitudes is preserved: e.g. with dim [3, 2, 3] and
    // strides [1, 9, 3], the compacted strides become [1, 6, 3] rather
    // than the default layout's [6, 3, 1].
    //
    // `stride_order[k]` is the axis with the k-th smallest stride.
    // NOTE(review): strides are compared in their raw `usize`
    // representation, which orders negative strides (two's complement)
    // after all positive ones — confirm this is the intended ordering.
    let mut stride_order = (0..dim.ndim()).collect::<Vec<_>>();
    stride_order.sort_unstable_by(|&i, &j| strides[i].cmp(&strides[j]));
    let mut next_stride = 1;
    for &axis in stride_order.iter() {
        shrinked_stride[axis] = next_stride;
        next_stride *= dim[axis];
    }

    // Decode a flat offset in the compacted layout into a full index.
    // Axes must be peeled off from the largest compacted stride down to
    // the smallest, which is exactly `stride_order` reversed. (Sorting a
    // fresh index list by the *values* stored in `stride_order` is not
    // equivalent for permutations that are not involutions, and decoded
    // indices out of range would make the `unwrap` below panic.)
    let offset_stride = |mut offset: usize| {
        let mut index = D::zeros(dim.ndim());
        for &axis in stride_order.iter().rev() {
            index[axis] = offset / shrinked_stride[axis];
            offset %= shrinked_stride[axis];
        }
        index
    };

    // Move every reachable element to its compacted position, reading it
    // from its old position in the original layout.
    //
    // NOTE(review): this in-place copy assumes each element's source
    // offset is never smaller than its destination offset, so sources are
    // not overwritten before being read — confirm this holds for negative
    // strides as well.
    let ptr_offset = unsafe { self.offset_from_data_ptr_to_logical_ptr() };
    // Rewind the logical pointer to the start of the allocation; offsets
    // below are relative to the allocation start.
    self.ptr = unsafe { self.ptr.sub(ptr_offset as usize) };
    for offset in 0..self.len() {
        let index = offset_stride(offset);
        let old_offset = dim.stride_offset_checked(&strides, &index).unwrap();
        unsafe {
            *self.ptr.as_ptr().add(offset) =
                *self.ptr.as_ptr().add((old_offset + ptr_offset) as usize);
        }
    }
    self.strides = shrinked_stride;
    // Release the now-unused tail of the allocation.
    self.data.shrink_to_fit(self.len());
}

/// Append an array to the array along an axis.
///
Expand Down Expand Up @@ -462,8 +533,7 @@ impl<A, D> Array<A, D>
/// [1., 1., 1., 1.],
/// [1., 1., 1., 1.]]);
/// ```
pub fn append(&mut self, axis: Axis, mut array: ArrayView<A, D>)
-> Result<(), ShapeError>
pub fn append(&mut self, axis: Axis, mut array: ArrayView<A, D>) -> Result<(), ShapeError>
where
A: Clone,
D: RemoveAxis,
Expand Down Expand Up @@ -556,7 +626,11 @@ impl<A, D> Array<A, D>
acc
} else {
let this_ax = ax.len as isize * ax.stride.abs();
if this_ax > acc { this_ax } else { acc }
if this_ax > acc {
this_ax
} else {
acc
}
}
});
let mut strides = self.strides.clone();
Expand All @@ -574,7 +648,10 @@ impl<A, D> Array<A, D>
0
};
debug_assert!(data_to_array_offset >= 0);
self.ptr = self.data.reserve(len_to_append).offset(data_to_array_offset);
self.ptr = self
.data
.reserve(len_to_append)
.offset(data_to_array_offset);

// clone elements from view to the array now
//
Expand Down Expand Up @@ -608,10 +685,13 @@ impl<A, D> Array<A, D>

if tail_view.ndim() > 1 {
sort_axes_in_default_order_tandem(&mut tail_view, &mut array);
debug_assert!(tail_view.is_standard_layout(),
"not std layout dim: {:?}, strides: {:?}",
tail_view.shape(), tail_view.strides());
}
debug_assert!(
tail_view.is_standard_layout(),
"not std layout dim: {:?}, strides: {:?}",
tail_view.shape(),
tail_view.strides()
);
}

// Keep track of currently filled length of `self.data` and update it
// on scope exit (panic or loop finish). This "indirect" way to
Expand All @@ -635,7 +715,6 @@ impl<A, D> Array<A, D>
data: &mut self.data,
};


// Safety: tail_view is constructed to have the same shape as array
Zip::from(tail_view)
.and_unchecked(array)
Expand Down Expand Up @@ -665,8 +744,11 @@ impl<A, D> Array<A, D>
///
/// This is an internal function for use by move_into and IntoIter only, safety invariants may need
/// to be upheld across the calls from those implementations.
pub(crate) unsafe fn drop_unreachable_raw<A, D>(mut self_: RawArrayViewMut<A, D>, data_ptr: *mut A, data_len: usize)
where
pub(crate) unsafe fn drop_unreachable_raw<A, D>(
mut self_: RawArrayViewMut<A, D>,
data_ptr: *mut A,
data_len: usize,
) where
D: Dimension,
{
let self_len = self_.len();
Expand Down Expand Up @@ -731,8 +813,11 @@ where
dropped_elements += 1;
}

assert_eq!(data_len, dropped_elements + self_len,
"Internal error: inconsistency in move_into");
assert_eq!(
data_len,
dropped_elements + self_len,
"Internal error: inconsistency in move_into"
);
}

/// Sort axes to standard order, i.e Axis(0) has biggest stride and Axis(n - 1) least stride
Expand Down Expand Up @@ -774,7 +859,6 @@ where
}
}


/// Sort axes to standard order, i.e Axis(0) has biggest stride and Axis(n - 1) least stride
///
/// Axes in a and b are sorted by the strides of `a`, and `a`'s axes should have stride >= 0 before
adamreichold marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
55 changes: 55 additions & 0 deletions tests/shrink_to_fit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use ndarray::{s, Array};

#[test]
fn dim_0() {
    // Slice along the first axis, shrink, and check the contents survive.
    let raw: Vec<_> = (0..4 * 5 * 6).collect();
    let full = Array::from_shape_vec((4, 5, 6), raw).unwrap();
    let mut sliced = full.slice_move(s![0..2, .., ..]);
    let expected = sliced.view().to_owned();
    sliced.shrink_to_fit();
    assert_eq!(sliced, expected);
}

#[test]
fn swap_axis_dim_0() {
    // Swap the first two axes before slicing so the strides are permuted.
    let raw: Vec<_> = (0..4 * 5 * 6).collect();
    let mut full = Array::from_shape_vec((4, 5, 6), raw).unwrap();
    full.swap_axes(0, 1);
    let mut sliced = full.slice_move(s![2..3, .., ..]);
    let expected = sliced.view().to_owned();
    sliced.shrink_to_fit();
    assert_eq!(sliced, expected);
}

#[test]
fn swap_axis_dim() {
    // Swap the last two axes and take a strided slice before shrinking.
    let raw: Vec<_> = (0..4 * 5 * 6).collect();
    let mut full = Array::from_shape_vec((4, 5, 6), raw).unwrap();
    full.swap_axes(2, 1);
    let mut sliced = full.slice_move(s![2..3, 0..3, 0..;2]);
    let expected = sliced.view().to_owned();
    sliced.shrink_to_fit();
    assert_eq!(sliced, expected);
}

#[test]
fn stride_negative() {
    // A reversed axis gives a negative stride; shrinking must preserve it.
    let raw: Vec<_> = (0..4 * 5 * 6).collect();
    let full = Array::from_shape_vec((4, 5, 6), raw).unwrap();
    let mut sliced = full.slice_move(s![2..3, 0..3, 0..;-1]);
    let expected = sliced.view().to_owned();
    sliced.shrink_to_fit();
    assert_eq!(sliced, expected);
}