Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement arithmetic ops on more combinations of types #744

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 76 additions & 35 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ macro_rules! impl_binary_op(
/// Perform elementwise
#[doc=$doc]
/// between `self` and `rhs`,
/// and return the result (based on `self`).
///
/// `self` must be an `Array` or `ArcArray`.
/// and return the result.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
Expand All @@ -64,13 +62,13 @@ impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension,
E: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> ArrayBase<S, D>
type Output = Array<A, D>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Array<A, D>
{
self.$mth(&rhs)
}
Expand All @@ -79,7 +77,7 @@ where
/// Perform elementwise
#[doc=$doc]
/// between `self` and reference `rhs`,
/// and return the result (based on `self`).
/// and return the result.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
Expand All @@ -88,18 +86,19 @@ impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension,
E: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(mut self, rhs: &ArrayBase<S2, E>) -> ArrayBase<S, D>
type Output = Array<A, D>;
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Array<A, D>
{
self.zip_mut_with(rhs, |x, y| {
let mut lhs = self.into_owned();
lhs.zip_mut_with(rhs, |x, y| {
*x = x.clone() $operator y.clone();
});
self
lhs
}
}

Expand Down Expand Up @@ -129,22 +128,45 @@ where

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
/// and return the result (based on `self`).
/// between `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// `self` must be an `Array` or `ArcArray`.
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension,
E: Dimension,
{
type Output = Array<A, D>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Array<A, D> {
// FIXME: Can we co-broadcast arrays here? And how?
self.to_owned().$mth(rhs)
}
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
/// and return the result.
impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
where A: Clone + $trt<B, Output=A>,
S: DataOwned<Elem=A> + DataMut,
S: Data<Elem=A>,
D: Dimension,
B: ScalarOperand,
{
type Output = ArrayBase<S, D>;
fn $mth(mut self, x: B) -> ArrayBase<S, D> {
self.unordered_foreach_mut(move |elt| {
type Output = Array<A, D>;
fn $mth(self, x: B) -> Array<A, D> {
let mut lhs = self.into_owned();
lhs.unordered_foreach_mut(move |elt| {
*elt = elt.clone() $operator x.clone();
});
self
lhs
}
}

Expand Down Expand Up @@ -183,17 +205,17 @@ macro_rules! impl_scalar_lhs_op {
// these have no doc -- they are not visible in rustdoc
// Perform elementwise
// between the scalar `self` and array `rhs`,
// and return the result (based on `self`).
// and return the result.
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
where S: DataOwned<Elem=$scalar> + DataMut,
where S: Data<Elem=$scalar>,
D: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
type Output = Array<$scalar, D>;
fn $mth(self, rhs: ArrayBase<S, D>) -> Array<$scalar, D> {
if_commutative!($commutative {
rhs.$mth(self)
} or {{
let mut rhs = rhs;
let mut rhs = rhs.into_owned();
rhs.unordered_foreach_mut(move |elt| {
*elt = self $operator *elt;
});
Expand Down Expand Up @@ -293,16 +315,17 @@ mod arithmetic_ops {
impl<A, S, D> Neg for ArrayBase<S, D>
where
A: Clone + Neg<Output = A>,
S: DataOwned<Elem = A> + DataMut,
S: Data<Elem = A>,
D: Dimension,
{
type Output = Self;
type Output = Array<A, D>;
/// Perform an elementwise negation of `self` and return the result.
fn neg(mut self) -> Self {
self.unordered_foreach_mut(|elt| {
fn neg(self) -> Array<A, D> {
let mut array = self.into_owned();
array.unordered_foreach_mut(|elt| {
*elt = -elt.clone();
});
self
array
}
}

Expand All @@ -323,16 +346,17 @@ mod arithmetic_ops {
impl<A, S, D> Not for ArrayBase<S, D>
where
A: Clone + Not<Output = A>,
S: DataOwned<Elem = A> + DataMut,
S: Data<Elem = A>,
D: Dimension,
{
type Output = Self;
type Output = Array<A, D>;
/// Perform an elementwise unary not of `self` and return the result.
fn not(mut self) -> Self {
self.unordered_foreach_mut(|elt| {
fn not(self) -> Array<A, D> {
let mut array = self.into_owned();
array.unordered_foreach_mut(|elt| {
*elt = !elt.clone();
});
self
array
}
}

Expand All @@ -359,6 +383,23 @@ mod assign_ops {
($trt:ident, $method:ident, $doc:expr) => {
use std::ops::$trt;

#[doc=$doc]
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isn’t possible.
impl<A, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<A>,
S: DataMut<Elem = A>,
S2: Data<Elem = A>,
D: Dimension,
E: Dimension,
{
fn $method(&mut self, rhs: ArrayBase<S2, E>) {
self.$method(&rhs)
}
}

#[doc=$doc]
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
Expand Down
31 changes: 13 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,18 +607,14 @@ pub type Ixs = isize;
///
/// ### Binary Operators with Two Arrays
///
/// Let `A` be an array or view of any kind. Let `B` be an array
/// with owned storage (either `Array` or `ArcArray`).
/// Let `C` be an array with mutable data (either `Array`, `ArcArray`
/// or `ArrayViewMut`).
/// The following combinations of operands
/// are supported for an arbitrary binary operator denoted by `@` (it can be
/// `+`, `-`, `*`, `/` and so on).
///
/// - `&A @ &A` which produces a new `Array`
/// - `B @ A` which consumes `B`, updates it with the result, and returns it
/// - `B @ &A` which consumes `B`, updates it with the result, and returns it
/// - `C @= &A` which performs an arithmetic operation in place
/// Let `A` be an array or view of any kind. Let `M` be an array with mutable
/// data (either `Array`, `ArcArray` or `ArrayViewMut`). The following
/// combinations of operands are supported for an arbitrary binary operator
/// denoted by `@` (it can be `+`, `-`, `*`, `/` and so on).
///
/// - `&A @ &A` or `&A @ A` which produce a new `Array`
/// - `A @ &A` or `A @ A` which may reuse the allocation of the LHS if it's an owned array
/// - `M @= &A` or `M @= A` which performs an arithmetic operation in place on `M`
///
/// Note that the element type needs to implement the operator trait and the
/// `Clone` trait.
Expand Down Expand Up @@ -647,17 +643,16 @@ pub type Ixs = isize;
/// `ScalarOperand` docs has the detailed condtions).
///
/// - `&A @ K` or `K @ &A` which produces a new `Array`
/// - `B @ K` or `K @ B` which consumes `B`, updates it with the result and returns it
/// - `C @= K` which performs an arithmetic operation in place
/// - `A @ K` or `K @ A` which may reuse the allocation of the array if it's an owned array
/// - `M @= K` which performs an arithmetic operation in place
///
/// ### Unary Operators
///
/// Let `A` be an array or view of any kind. Let `B` be an array with owned
/// storage (either `Array` or `ArcArray`). The following operands are supported
/// for an arbitrary unary operator denoted by `@` (it can be `-` or `!`).
/// The following operands are supported for an arbitrary unary operator
/// denoted by `@` (it can be `-` or `!`).
///
/// - `@&A` which produces a new `Array`
/// - `@B` which consumes `B`, updates it with the result, and returns it
/// - `@A` which may reuse the allocation of the array if it's an owned array
///
/// ## Broadcasting
///
Expand Down
10 changes: 5 additions & 5 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,11 +394,11 @@ fn test_add() {
}

let B = A.clone();
A = A + &B;
assert_eq!(A[[0, 0]], 0);
assert_eq!(A[[0, 1]], 2);
assert_eq!(A[[1, 0]], 4);
assert_eq!(A[[1, 1]], 6);
let C = A + &B;
assert_eq!(C[[0, 0]], 0);
assert_eq!(C[[0, 1]], 2);
assert_eq!(C[[1, 0]], 4);
assert_eq!(C[[1, 1]], 6);
}

#[test]
Expand Down