-
Notifications
You must be signed in to change notification settings - Fork 307
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Impl a lifetime-relaxed broadcast for ArrayView
ArrayView::broadcast has a lifetime that depends on &self instead of its internal buffer. This prevents writing some types of functions in an allocation-free way. For instance, take the numpy `meshgrid` function: It could be implemented like so: ```rust fn meshgrid_2d<'a, 'b>(coords_x: ArrayView1<'a, X>, coords_y: ArrayView1<'b, X>) -> (ArrayView2<'a, X>, ArrayView2<'b, X>) { let x_len = coords_x.shape()[0]; let y_len = coords_y.shape()[0]; let coords_x_s = coords_x.into_shape((1, y_len)).unwrap(); let coords_x_b = coords_x_s.broadcast((x_len, y_len)).unwrap(); let coords_y_s = coords_y.into_shape((x_len, 1)).unwrap(); let coords_y_b = coords_y_s.broadcast((x_len, y_len)).unwrap(); (coords_x_b, coords_y_b) } ``` Unfortunately, this doesn't work, because `coords_x_b` is bound to the lifetime of `coord_x_s`, instead of being bound to 'a. This commit introduces a new function, broadcast_ref, that does just that.
- Loading branch information
Showing
2 changed files
with
85 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
// Copyright 2014-2016 bluss and ndarray developers. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
use crate::imp_prelude::*; | ||
use crate::dimension::IntoDimension; | ||
use crate::dimension::size_of_shape_checked; | ||
|
||
impl<'a, A, D> ArrayView<'a, A, D> | ||
where | ||
D: Dimension, | ||
{ | ||
/// Broadcasts an `ArrayView`. See [`ArrayBase::broadcast`]. | ||
/// | ||
/// This is a specialized version of [`ArrayBase::broadcast`] that transfers | ||
/// the view's lifetime to the output. | ||
pub fn broadcast_ref<E>(&self, dim: E) -> Option<ArrayView<'a, A, E::Dim>> | ||
where | ||
E: IntoDimension, | ||
{ | ||
/// Return new stride when trying to grow `from` into shape `to` | ||
/// | ||
/// Broadcasting works by returning a "fake stride" where elements | ||
/// to repeat are in axes with 0 stride, so that several indexes point | ||
/// to the same element. | ||
/// | ||
/// **Note:** Cannot be used for mutable iterators, since repeating | ||
/// elements would create aliasing pointers. | ||
fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> { | ||
// Make sure the product of non-zero axis lengths does not exceed | ||
// `isize::MAX`. This is the only safety check we need to perform | ||
// because all the other constraints of `ArrayBase` are guaranteed | ||
// to be met since we're starting from a valid `ArrayBase`. | ||
let _ = size_of_shape_checked(to).ok()?; | ||
|
||
let mut new_stride = to.clone(); | ||
// begin at the back (the least significant dimension) | ||
// size of the axis has to either agree or `from` has to be 1 | ||
if to.ndim() < from.ndim() { | ||
return None; | ||
} | ||
|
||
{ | ||
let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev(); | ||
for ((er, es), dr) in from | ||
.slice() | ||
.iter() | ||
.rev() | ||
.zip(stride.slice().iter().rev()) | ||
.zip(new_stride_iter.by_ref()) | ||
{ | ||
/* update strides */ | ||
if *dr == *er { | ||
/* keep stride */ | ||
*dr = *es; | ||
} else if *er == 1 { | ||
/* dead dimension, zero stride */ | ||
*dr = 0 | ||
} else { | ||
return None; | ||
} | ||
} | ||
|
||
/* set remaining strides to zero */ | ||
for dr in new_stride_iter { | ||
*dr = 0; | ||
} | ||
} | ||
Some(new_stride) | ||
} | ||
let dim = dim.into_dimension(); | ||
|
||
// Note: zero strides are safe precisely because we return an read-only view | ||
let broadcast_strides = match upcast(&dim, &self.dim, &self.strides) { | ||
Some(st) => st, | ||
None => return None, | ||
}; | ||
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) } | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
mod constructors; | ||
mod conversions; | ||
mod indexing; | ||
mod methods; | ||
mod splitting; | ||
|
||
pub use constructors::*; | ||
|