From 2cd4ca612b29ab161dddd1ba99d9b163508133bc Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Mon, 25 Sep 2017 21:30:23 +0800 Subject: [PATCH] autograd: adop grad_req_map for mark_variables - add test cases for mark_variables --- src/autograd.jl | 25 ++++++++++++++++++------- test/unittest/autograd.jl | 29 +++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/src/autograd.jl b/src/autograd.jl index fdc2700ef..47b4f14ce 100644 --- a/src/autograd.jl +++ b/src/autograd.jl @@ -18,7 +18,7 @@ for gradient computation. ## Returns -* previous state before this set. +Previous state before this set """ function _set_recording(state::Bool)::Bool prev = Ref{Cint}(C_NULL) @@ -80,7 +80,6 @@ end # __enter__ prev_is_record = _set_recording(is_record) prev_train_mode = _set_training(train_mode) - #= println("$is_record $train_mode $prev_is_record $prev_train_mode") =# try f() @@ -304,27 +303,39 @@ mark_variables(var::Vector{NDArray}, grads::Vector{NDArray}, grad_reqs=:write) = @inline function _mark_variables(vars::Vector{NDArray}, grads::Vector{NDArray}, grad_reqs::Union{Vector{Symbol}, Symbol}=:write) - # TODO: leverage grad reqs map from #283 if length(vars) != length(grads) throw(ArgumentError("number of variables and gradients not matched")) end + var_hdls = map(arr -> arr.handle, vars) grad_hdls = map(arr -> arr.handle, grads) if isa(grad_reqs, Symbol) - grad_reqs = MX_uint[GRAD_WRITE for i ∈ 1:length(vars)] # FIXME + val = get(grad_req_map, grad_reqs, false) + if val == false + throw(ArgumentError("invalid grad_reqs $grad_reqs")) + end + + grad_reqs = MX_uint[val for i ∈ 1:length(vars)] else if length(vars) != length(grad_reqs) throw(ArgumentError("number of variables and gradients not matched")) end - grad_reqs = MX_uint[GRAD_WRITE for i ∈ 1:length(vars)] # FIXME + + grad_reqs = map(grad_reqs) do k + val = get(grad_req_map, k, false) + if val == false + throw(ArgumentError("invalid grad_reqs $k")) + end + + MX_uint(val) + end end @mxcall(:MXAutogradMarkVariables, (MX_uint, Ref{MX_handle}, Ptr{MX_uint}, Ref{MX_handle}), - length(vars), var_hdls, MX_uint[GRAD_WRITE], - grad_hdls) + length(vars), var_hdls, grad_reqs, grad_hdls) end """ diff --git a/test/unittest/autograd.jl b/test/unittest/autograd.jl index 9de90b6e9..2904c975b 100644 --- a/test/unittest/autograd.jl +++ b/test/unittest/autograd.jl @@ -22,8 +22,37 @@ function test_grad() end +function test_mark_variables() + info("AutoGrad::mark_variables") + x = mx.zeros(4) + ẋ = mx.zeros(4) + y = mx.zeros(4) + ẏ = mx.zeros(4) + mx.mark_variables([x, y], [ẋ, ẏ], [:nop, :nop]) + ẋ[:] = 42 + ẏ[:] = 24 + + @test copy(mx.grad(x)) == [42, 42, 42, 42] + @test copy(mx.grad(y)) == [24, 24, 24, 24] + + info("AutoGrad::mark_variables::invalid grad_reqs") + x = mx.zeros(4) + y = mx.zeros(4) + @test_throws ArgumentError mx.mark_variables(x, y, :magic) + @test_throws ArgumentError mx.mark_variables([x], [y], [:magic]) + + info("AutoGrad::mark_variables::args length mismatch") + x = mx.zeros(4) + y = mx.zeros(4) + z = mx.zeros(4) + @test_throws ArgumentError mx.mark_variables([x], [y, z]) +end + + @testset "AutoGrad Test" begin test_grad() + test_mark_variables() end + end # model TestAutoGrad