Skip to content

Commit

Permalink
autograd: add test cases for get_symbol
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin committed Oct 7, 2017
1 parent fe228c7 commit 7c6c3de
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 8 deletions.
52 changes: 44 additions & 8 deletions src/autograd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,28 @@ function backward(heads::Vector{NDArray}, head_grads=Union{Vector, Void};
output_handles = map(arr -> arr.handle, heads)

if head_grads == nothing
@mxcall(:MXAutogradBackwardEx,
(MX_uint, Ptr{MX_handle}, Ptr{MX_handle}, Cint, Cint),
length(output_handles), output_handles, C_NULL,
retain_graph, train_mode)
@mxcall(
:MXAutogradBackwardEx,
(MX_uint,
Ptr{MX_handle},
Ptr{MX_handle},
MX_uint,
Ptr{MX_handle},
Cint,
Cint,
Cint,
Ptr{MX_handle},
Ptr{MX_handle}),
length(output_handles),
output_handles,
C_NULL,
0,
C_NULL,
retain_graph,
false, # create_graph
train_mode,
C_NULL,
C_NULL)
return
end

Expand All @@ -240,10 +258,28 @@ function backward(heads::Vector{NDArray}, head_grads=Union{Vector, Void};
end
end
@assert length(output_handles) == length(ograd_handles)
@mxcall(:MXAutogradBackwardEx,
(MX_uint, Ptr{MX_handle}, Ptr{MX_handle}, Cint, Cint),
length(output_handles), output_handles, ograd_handles,
retain_graph, train_mode)
@mxcall(
:MXAutogradBackwardEx,
(MX_uint,
Ptr{MX_handle},
Ptr{MX_handle},
MX_uint,
Ptr{MX_handle},
Cint,
Cint,
Cint,
Ptr{MX_handle},
Ptr{MX_handle}),
length(output_handles),
output_handles,
ograd_handles,
0,
C_NULL,
retain_graph,
false, # create_graph
train_mode,
C_NULL,
C_NULL)
end

"""
Expand Down
49 changes: 49 additions & 0 deletions test/unittest/autograd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,54 @@ function test_mark_variables()
end


function test_record()
let x = mx.NDArray([1 2; 3 4])
info("AutoGrad::record::backward")

mx.attach_grad(x)
y = mx.record() do
mx.square(x)
end

@test copy(y) == [1 4; 9 16]

mx.backward(y)
# gradient is 2x
@test copy(mx.grad(x)) == [2 4; 6 8]
end

let x = mx.NDArray([1 2; 3 4])
info("AutoGrad::record::get_symbol")

mx.attach_grad(x)
y = mx.record() do
mx.square(x)
end

@test copy(y) == [1 4; 9 16]

@test isa(mx.get_symbol(y), mx.SymbolicNode)
end

let x = mx.NDArray([1 2; 3 4])
info("AutoGrad::record::backward(retain_graph=true)")

mx.attach_grad(x)
y = mx.record() do
mx.square(x)
end

@test copy(y) == [1 4; 9 16]

mx.backward(y, retain_graph=true)
# gradient is 2x
@test copy(mx.grad(x)) == [2 4; 6 8]

@test isa(mx.get_symbol(y), mx.SymbolicNode)
end
end # function test_record()


function test_get_symbol()
info("AutoGrad::get_symbol")

Expand All @@ -62,6 +110,7 @@ end
@testset "AutoGrad Test" begin
test_grad()
test_mark_variables()
test_record()
test_get_symbol()
end

Expand Down

0 comments on commit 7c6c3de

Please sign in to comment.