Skip to content

Commit

Permalink
ndarray: change internal api of minus to help autograd
Browse files Browse the repository at this point in the history
address #274 (comment)

Although this patch cannot pass `@inferred`, but `code_warntype` give me
this:
```julia
end::MXNet.mx.NDArray{_,_} where _ where _
```

And seems it doesn't hurt performance.
  • Loading branch information
iblislin committed Dec 7, 2017
1 parent f8d4f62 commit 76e0fab
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
7 changes: 4 additions & 3 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,10 @@ import Base: -
Subtraction `x - y`, of scalar types or `NDArray`.
Or create the negative of `x`.
"""
-(x::NDArray) = _mul_scalar(x, scalar=-one(eltype(x)))
-(x::NDArray, y::NDArrayOrReal) = sub_from!(copy(x, context(x)), y)
-(x::Real, y::NDArray) = -y .+ x
-(x::NDArray) = _mul_scalar(x, scalar = -one(eltype(x)))
-(x::NDArray, y::NDArray) = _minus(x, y)
-(x::NDArray, y::Real) = _minus_scalar(x, scalar = y)
-(y::Real, x::NDArray) = _rminus_scalar(x, scalar = y)

broadcast_(::typeof(-), x::NDArray, y::NDArrayOrReal) = x - y
broadcast_(::typeof(-), x::Real, y::NDArray) = x - y
Expand Down
6 changes: 0 additions & 6 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,6 @@ function test_minus()
let x = mx.NDArray([1, 2, 3])
@test copy(x .- π) [-2, -1, 0]
end

info("NDArray::minus::type stablility")
let x = mx.zeros(dims), y = mx.ones(dims)
@inferred x - y
@inferred x .- y
end
end

function test_mul()
Expand Down

0 comments on commit 76e0fab

Please sign in to comment.