Commit 3146501

autograd: fix gradient of multiplication

iblislin committed Oct 8, 2017
1 parent 148118a, commit 3146501

Showing 3 changed files with 66 additions and 27 deletions.
25 changes: 12 additions & 13 deletions src/ndarray.jl
@@ -638,34 +638,33 @@ function mul_to!(dst :: NDArray, arg :: Union{Real, NDArray})
   return dst
 end
 
+mul(x::NDArray, y::NDArray) = _mul(x, y)
+mul(x::NDArray, s::Real) = _mul_scalar(x, scalar=s)
+mul(s::Real, x::NDArray) = _mul_scalar(x, scalar=s)
+
 import Base: *
 
 """
     .*(arg0, arg1)
 
 Elementwise multiplication of `arg0` and `arg`; either may be a scalar or an `NDArray`.
 """
-@compatdot function Base.broadcast(::typeof(*), arg0 :: NDArray, arg :: Union{Real, NDArray})
-  ret = copy(arg0, context(arg0))
-  mul_to!(ret, arg)
+@compatdot function Base.broadcast(::typeof(*), arg0 :: NDArray,
+                                   arg :: Union{Real, NDArray})
+  mul(arg0, arg)
 end
 @compatdot function Base.broadcast(::typeof(*), arg0 :: Real, arg :: NDArray)
-  arg .* arg0
+  mul(arg, arg0)
 end
 
 """
     *(arg0, arg1)
 
-Currently only multiplication a scalar with an `NDArray` is implemented. Matrix multiplication
-is to be added soon.
+Currently only multiplication of a scalar with an `NDArray` is implemented.
+Matrix multiplication is to be added soon.
 """
-function *(arg0 :: NDArray, arg :: Real)
-  ret = copy(arg0, context(arg0))
-  mul_to!(ret, arg)
-end
-function *(arg0 :: Real, arg :: NDArray)
-  *(arg, arg0)
-end
+*(arg0 :: NDArray, arg :: Real) = mul(arg0, arg)
+*(arg0 :: Real, arg :: NDArray) = mul(arg0, arg)
 
 """
     div_from!(dst :: NDArray, arg :: Union{Real, NDArray})
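The gist of the fix: the old `.*`/`*` methods produced their result with an in-place `mul_to!` on a copy, while the new `mul` helpers dispatch to the `_mul`/`_mul_scalar` operators, which the autograd recorder can track (the in-place path evidently could not, hence the broken gradients this commit repairs). A minimal sketch of what this enables, mirroring the new tests below and assuming only the `mx.attach_grad`/`mx.record`/`mx.backward` API those tests use:

using MXNet

x = mx.NDArray([1 2; 3 4])
g = mx.attach_grad(x)  # allocate a gradient buffer for x
y = mx.record() do     # run the computation under the autograd recorder
  2x .* x              # elementwise: y = 2 .* x .^ 2
end
mx.backward(y)         # accumulate dy/dx = 4x into g
copy(g)                # => [4 8; 12 16]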
34 changes: 34 additions & 0 deletions test/unittest/autograd.jl
@@ -107,11 +107,45 @@ function test_getsymbol()
 end
 
 
+function test_mul()
+  info("AutoGrad::mul")
+
+  let x = mx.NDArray([1 2; 3 4])
+    g = mx.attach_grad(x)
+    y = mx.record() do
+      2x .* x
+    end
+
+    @test copy(g) == [0 0; 0 0]
+    @test copy(y) == [2 8; 18 32]
+
+    mx.backward(y)
+    # gradient is 4x
+    @test copy(g) == [4 8; 12 16]
+  end
+
+  let x = mx.NDArray([1 2; 3 4])
+    g = mx.attach_grad(x)
+    y = mx.record() do
+      x * 2 .* x
+    end
+
+    @test copy(g) == [0 0; 0 0]
+    @test copy(y) == [2 8; 18 32]
+
+    mx.backward(y)
+    # gradient is 4x
+    @test copy(g) == [4 8; 12 16]
+  end
+end
+
+
 @testset "AutoGrad Test" begin
   test_getgrad()
   test_mark_variables()
   test_record()
   test_getsymbol()
+  test_mul()
 end


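Both `let` blocks in `test_mul` compute the same function, once through the broadcast path (`2x .* x`) and once through the scalar-`*` path (`x * 2 .* x`), so a single derivation covers both sets of assertions. Elementwise, y_i = 2 * x_i^2 gives dy_i/dx_i = 4 * x_i, and assuming `backward` seeds the non-scalar `y` with a head gradient of ones (MXNet's default behavior, to my knowledge), the gradient at x = [1 2; 3 4] is [4 8; 12 16], exactly the values both blocks test.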
34 changes: 20 additions & 14 deletions test/unittest/ndarray.jl
@@ -223,6 +223,12 @@ function test_mul()
   t6, a6 = rand_tensors(Float16, dims)
   scalar_small = Float16(1e-5)
   @test reldiff(t6 * scalar_small, copy(a6 .* scalar_small)) < 1e-1
+
+  let x = mx.NDArray([1 2; 3 4])
+    @test eltype(x) == Int
+    @test copy(1.5x) == [1 2; 3 4]
+    @test copy(1.9x) == [1 2; 3 4]
+  end
 end
 
 function test_div()
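The new `let` block pins down a quirk of the scalar path: the scalar is evidently coerced to the array's element type before the multiply, so on an Int array the fractional part is dropped (1.5 and 1.9 both act as 1) rather than promoting the result to floating point. A short illustration, assuming `using MXNet` as above; the `Float64` contrast is my addition, not part of the test:

x = mx.NDArray([1 2; 3 4])          # eltype(x) == Int
copy(1.5x)                          # => [1 2; 3 4]  (scalar truncated to 1)

y = mx.NDArray(Float64[1 2; 3 4])   # floating eltype: nothing is dropped
copy(1.5y)                          # => [1.5 3.0; 4.5 6.0]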
@@ -392,25 +398,25 @@ function test_eltype()
 end
 
 function test_reshape()
-    info("NDArray::reshape")
-    A = rand(2, 3, 4)
+  info("NDArray::reshape")
+  A = rand(2, 3, 4)
 
-    B = reshape(mx.NDArray(A), 4, 3, 2)
-    @test size(B) == (4, 3, 2)
-    @test copy(B)[3, 1, 1] == A[1, 2, 1]
+  B = reshape(mx.NDArray(A), 4, 3, 2)
+  @test size(B) == (4, 3, 2)
+  @test copy(B)[3, 1, 1] == A[1, 2, 1]
 
-    C = reshape(mx.NDArray(A), (4, 3, 2))
-    @test size(C) == (4, 3, 2)
-    @test copy(C)[3, 1, 1] == A[1, 2, 1]
+  C = reshape(mx.NDArray(A), (4, 3, 2))
+  @test size(C) == (4, 3, 2)
+  @test copy(C)[3, 1, 1] == A[1, 2, 1]
 
-    info("NDArray::reshape::reverse")
-    A = mx.zeros(10, 5, 4)
+  info("NDArray::reshape::reverse")
+  A = mx.zeros(10, 5, 4)
 
-    B = reshape(A, -1, 0)
-    @test size(B) == (40, 5)
+  B = reshape(A, -1, 0)
+  @test size(B) == (40, 5)
 
-    C = reshape(A, -1, 0, reverse=true)
-    @test size(C) == (50, 4)
+  C = reshape(A, -1, 0, reverse=true)
+  @test size(C) == (50, 4)
 end

 function test_kwargs()
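The change to `test_reshape` is whitespace-only (a re-indent), but its expected shapes deserve a word: in MXNet's reshape, a `0` in the target shape copies the corresponding input dimension, `-1` infers whatever remains, and `reverse=true` matches these special values against the input dimensions from the right instead of the left (MXNet's documented reshape semantics, summarized here). For the `(10, 5, 4)` array in the test, again assuming `using MXNet`:

A = mx.zeros(10, 5, 4)           # 200 elements

reshape(A, -1, 0)                # 0 copies dim 2 (= 5); -1 infers 200 / 5
                                 # => size (40, 5)

reshape(A, -1, 0, reverse=true)  # matched from the right: 0 copies dim 3 (= 4);
                                 # -1 infers 200 / 4  => size (50, 4)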
