From d037d53016012176b438412349bceb28c5d30760 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Thu, 7 Dec 2017 22:08:16 +0800 Subject: [PATCH] ndarray: change internal api of minus to help autograd address https://github.com/dmlc/MXNet.jl/pull/274#issuecomment-349951876 Although this patch cannot pass `@inferred`, but `code_warntype` give me this: ```julia end::MXNet.mx.NDArray{_,_} where _ where _ ``` And seems it doesn't hurt performance. --- src/ndarray.jl | 7 ++++--- test/unittest/ndarray.jl | 6 ------ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index 5501357fa..1b670291f 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -619,9 +619,10 @@ import Base: - Subtraction `x - y`, of scalar types or `NDArray`. Or create the negative of `x`. """ --(x::NDArray) = _mul_scalar(x, scalar=-one(eltype(x))) --(x::NDArray, y::NDArrayOrReal) = sub_from!(copy(x, context(x)), y) --(x::Real, y::NDArray) = -y .+ x +-(x::NDArray) = _mul_scalar(x, scalar = -one(eltype(x))) +-(x::NDArray, y::NDArray) = _minus(x, y) +-(x::NDArray, y::Real) = _minus_scalar(x, scalar = y) +-(y::Real, x::NDArray) = _rminus_scalar(x, scalar = y) broadcast_(::typeof(-), x::NDArray, y::NDArrayOrReal) = x - y broadcast_(::typeof(-), x::Real, y::NDArray) = x - y diff --git a/test/unittest/ndarray.jl b/test/unittest/ndarray.jl index 1a66034c3..e7261f011 100644 --- a/test/unittest/ndarray.jl +++ b/test/unittest/ndarray.jl @@ -280,12 +280,6 @@ function test_minus() scalar_large = Float16(1e4) @test t6 - scalar_small ≈ copy(a6 .- scalar_small) @test t6 - scalar_large ≈ copy(a6 .- scalar_large) - - info("NDArray::minus::type stablility") - let x = mx.zeros(dims), y = mx.ones(dims) - @inferred x - y - @inferred x .- y - end end function test_mul()