diff --git a/src/ndarray.jl b/src/ndarray.jl index 5501357fa..1b670291f 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -619,9 +619,10 @@ import Base: - Subtraction `x - y`, of scalar types or `NDArray`. Or create the negative of `x`. """ --(x::NDArray) = _mul_scalar(x, scalar=-one(eltype(x))) --(x::NDArray, y::NDArrayOrReal) = sub_from!(copy(x, context(x)), y) --(x::Real, y::NDArray) = -y .+ x +-(x::NDArray) = _mul_scalar(x, scalar = -one(eltype(x))) +-(x::NDArray, y::NDArray) = _minus(x, y) +-(x::NDArray, y::Real) = _minus_scalar(x, scalar = y) +-(y::Real, x::NDArray) = _rminus_scalar(x, scalar = y) broadcast_(::typeof(-), x::NDArray, y::NDArrayOrReal) = x - y broadcast_(::typeof(-), x::Real, y::NDArray) = x - y diff --git a/test/unittest/ndarray.jl b/test/unittest/ndarray.jl index 1a66034c3..e7261f011 100644 --- a/test/unittest/ndarray.jl +++ b/test/unittest/ndarray.jl @@ -280,12 +280,6 @@ function test_minus() scalar_large = Float16(1e4) @test t6 - scalar_small ≈ copy(a6 .- scalar_small) @test t6 - scalar_large ≈ copy(a6 .- scalar_large) - - info("NDArray::minus::type stablility") - let x = mx.zeros(dims), y = mx.ones(dims) - @inferred x - y - @inferred x .- y - end end function test_mul()