Skip to content

Commit

Permalink
Make aview* free functions be const fns
Browse files Browse the repository at this point in the history
  • Loading branch information
jturner314 authored and bluss committed Mar 10, 2024
1 parent be3219b commit f9bb795
Showing 1 changed file with 83 additions and 9 deletions.
92 changes: 83 additions & 9 deletions src/free_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use std::mem::{forget, size_of};
use std::ptr::NonNull;

use crate::imp_prelude::*;
use crate::{dimension, ArcArray1, ArcArray2};
Expand Down Expand Up @@ -65,32 +66,105 @@ pub fn rcarr1<A: Clone>(xs: &[A]) -> ArcArray1<A> {
}

/// Create a zero-dimensional array view borrowing `x`.
pub fn aview0<A>(x: &A) -> ArrayView0<'_, A> {
unsafe { ArrayView::from_shape_ptr(Ix0(), x) }
pub const fn aview0<A>(x: &A) -> ArrayView0<'_, A> {
ArrayBase {
data: ViewRepr::new(),
// Safe because references are always non-null.
ptr: unsafe { NonNull::new_unchecked(x as *const A as *mut A) },
dim: Ix0(),
strides: Ix0(),
}
}

/// Create a one-dimensional array view with elements borrowing `xs`.
///
/// **Panics** if the length of the slice overflows `isize`. (This can only
/// occur if `A` is zero-sized, because slices cannot contain more than
/// `isize::MAX` number of bytes.)
///
/// ```
/// use ndarray::aview1;
/// use ndarray::{aview1, ArrayView1};
///
/// let data = [1.0; 1024];
///
/// // Create a 2D array view from borrowed data
/// let a2d = aview1(&data).into_shape((32, 32)).unwrap();
///
/// assert_eq!(a2d.sum(), 1024.0);
///
/// // Create a const 1D array view
/// const C: ArrayView1<'static, f64> = aview1(&[1., 2., 3.]);
///
/// assert_eq!(C.sum(), 6.);
/// ```
pub fn aview1<A>(xs: &[A]) -> ArrayView1<'_, A> {
ArrayView::from(xs)
pub const fn aview1<A>(xs: &[A]) -> ArrayView1<'_, A> {
if size_of::<A>() == 0 {
assert!(
xs.len() <= isize::MAX as usize,
"Slice length must fit in `isize`.",
);
}
ArrayBase {
data: ViewRepr::new(),
// Safe because references are always non-null.
ptr: unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) },
dim: Ix1(xs.len()),
strides: Ix1(1),
}
}

/// Create a two-dimensional array view with elements borrowing `xs`.
///
/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A
/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes).
pub fn aview2<A, const N: usize>(xs: &[[A; N]]) -> ArrayView2<'_, A> {
ArrayView2::from(xs)
/// **Panics** if the product of non-zero axis lengths overflows `isize` (This
/// can only occur if A is zero-sized or if `N` is zero, because slices cannot
/// contain more than `isize::MAX` number of bytes).
///
/// ```
/// use ndarray::{aview2, ArrayView2};
///
/// let data = vec![[1., 2., 3.], [4., 5., 6.]];
///
/// let view = aview2(&data);
/// assert_eq!(view.sum(), 21.);
///
/// // Create a const 2D array view
/// const C: ArrayView2<'static, f64> = aview2(&[[1., 2., 3.], [4., 5., 6.]]);
/// assert_eq!(C.sum(), 21.);
/// ```
pub const fn aview2<A, const N: usize>(xs: &[[A; N]]) -> ArrayView2<'_, A> {
let cols = N;
let rows = xs.len();
if size_of::<A>() == 0 {
if let Some(n_elems) = rows.checked_mul(cols) {
assert!(
rows <= isize::MAX as usize
&& cols <= isize::MAX as usize
&& n_elems <= isize::MAX as usize,
"Product of non-zero axis lengths must not overflow isize.",
);
} else {
panic!("Overflow in number of elements.");
}
} else if N == 0 {
assert!(
rows <= isize::MAX as usize,
"Product of non-zero axis lengths must not overflow isize.",
);
}
// Safe because references are always non-null.
let ptr = unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) };
let dim = Ix2(rows, cols);
let strides = if rows == 0 || cols == 0 {
Ix2(0, 0)
} else {
Ix2(cols, 1)
};
ArrayBase {
data: ViewRepr::new(),
ptr,
dim,
strides,
}
}

/// Create a one-dimensional read-write array view with elements borrowing `xs`.
Expand Down

0 comments on commit f9bb795

Please sign in to comment.