From cb443b79778b42bbc4b88ff2a87e4bbf377cda07 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Mon, 18 Dec 2017 20:16:25 +0800
Subject: [PATCH] ndarray: porting Python's autograd (#274)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Ref: https://github.com/apache/incubator-mxnet/blob/065adb3702c110af7b537799be3ec9c16c27a72b/python/mxnet/autograd.py

* API ported

  * attach_grad
  * grad
  * mark_variables
  * get_symbol
  * record
  * pause
  * train_mode
  * predict_mode
  * backward

* An example

  ```julia
  x = NDArray([1 2; 3 4])
  mx.attach_grad!(x)
  y = mx.record() do
      mx.square(x)
  end

  mx.backward!(y)

  mx.getgrad(x)
  # 2×2 Array{Int64,2}:
  #  2  4
  #  6  8
  ```
---
 NEWS.md                   |   2 +
 src/MXNet.jl              |   1 +
 src/autograd.jl           | 387 ++++++++++++++++++++++++++++++++++++++
 src/base.jl               |   2 +-
 src/ndarray.jl            |   4 +-
 test/unittest/autograd.jl | 386 +++++++++++++++++++++++++++++++++++++
 test/unittest/ndarray.jl  |   2 +-
 7 files changed, 780 insertions(+), 4 deletions(-)
 create mode 100644 src/autograd.jl
 create mode 100644 test/unittest/autograd.jl

diff --git a/NEWS.md b/NEWS.md
index 70d8626f7..4540cba50 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -75,6 +75,8 @@
 
 ### `NDArray`
 
+* A port of Python's `autograd` for `NDArray` (#274)
+
 * `size(x, dims...)` is supported now. (#TBD)
 
   ```julia
diff --git a/src/MXNet.jl b/src/MXNet.jl
index 3583c140b..352d20aad 100644
--- a/src/MXNet.jl
+++ b/src/MXNet.jl
@@ -105,6 +105,7 @@
 include("broadcast.jl")
 include("ndarray.jl")
 include("random.jl")
+include("autograd.jl")
 
 include("name.jl")
 include("symbolic-node.jl")
diff --git a/src/autograd.jl b/src/autograd.jl
new file mode 100644
index 000000000..4584decb0
--- /dev/null
+++ b/src/autograd.jl
@@ -0,0 +1,387 @@
+# Autograd for NDArray
+# this is a port of Python's autograd module
+# https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py
+
+###############################################################################
+# Private util functions
+###############################################################################
+
+"""
+    _set_recording(state::Bool)::Bool
+
+Set status to recording/not recording. When recording, a graph will be
+constructed for gradient computation.
+
+## Parameters
+
+* `state::Bool`
+
+## Returns
+
+Previous state before this set.
+"""
+function _set_recording(state::Bool)::Bool
+  prev = Ref{Cint}(C_NULL)
+  @mxcall(:MXAutogradSetIsRecording, (Cint, Ref{Cint}), state, prev)
+  prev[]
+end
+
+_set_recording(::Void) = nothing
+
+"""
+Set status to training/predicting.
+For example, Dropout will drop inputs randomly when
+`train_mode = true` while simply passing through if `train_mode = false`.
+
+## Parameters
+
+* `train_mode::Bool`
+
+## Returns
+
+Previous state before this set.
+"""
+function _set_training(train_mode::Bool)::Bool
+  prev = Ref{Cint}(C_NULL)
+  @mxcall(:MXAutogradSetIsTraining, (Cint, Ref{Cint}), train_mode, prev)
+  prev[]
+end
+
+_set_training(::Void) = nothing
+
+###############################################################################
+# Public API
+###############################################################################
+
+"""
+    is_recording()::Bool
+
+Get status on recording/not recording.
+"""
+function is_recording()::Bool
+  state = Ref{Cint}(C_NULL)
+  @mxcall(:MXAutogradIsRecording, (Ref{Cint},), state)
+  state[]
+end
+
+"""
+    is_training()::Bool
+
+Get status on training/predicting.
+"""
+function is_training()::Bool
+  state = Ref{Cint}(C_NULL)
+  @mxcall(:MXAutogradIsTraining, (Ref{Cint},), state)
+  state[]
+end
+
+@inline function _record(f, is_record::Union{Void,Bool}, train_mode::Union{Void,Bool})
+  # Port from Python's `_RecordingStateScope` context manager
+  # __enter__
+  prev_is_record = _set_recording(is_record)
+  prev_train_mode = _set_training(train_mode)
+
+  try
+    f()
+  finally
+    # __exit__
+    if is_record != nothing && prev_is_record != is_record
+      _set_recording(prev_is_record)
+    end
+    if train_mode != nothing && prev_train_mode != train_mode
+      _set_training(prev_train_mode)
+    end
+  end
+end
+
+"""
+    record(f, train_mode = true)
+    record(train_mode = true) do
+      ...
+    end
+
+Returns an autograd recording scope context to be used in a `do` block,
+capturing code that needs gradients to be calculated.
+
+Parameter `train_mode::Bool` controls whether the forward pass is in training
+or predicting mode.
+This controls the behavior of some layers such as `Dropout`, `BatchNorm`.
+
+!!! note
+    When forwarding with `train_mode = false`, the corresponding backward
+    should also use `train_mode = false`, otherwise the gradient is undefined.
+
+```julia
+x = mx.NDArray([1 2; 3 4])
+∇ = mx.attach_grad!(x)
+y = mx.record() do
+  2x
+end
+mx.backward!(y)
+
+julia> ∇
+2×2 mx.NDArray{Int64,2} @ CPU0:
+ 2  2
+ 2  2
+```
+"""
+record(f, train_mode::Bool = true) = _record(f, true, train_mode)
+
+"""
+    pause(f, train_mode = false)
+    pause(train_mode = false) do
+      ...
+    end
+
+Create a scope context for code that does not need gradients to be calculated.
+
+```julia
+record() do
+  ...
+  pause() do
+    # testing, IO, gradient updates...
+  end
+end
+```
+"""
+pause(f, train_mode::Bool = false) = _record(f, false, train_mode)
+
+"""
+    train_mode(f)
+    train_mode() do
+      ...
+    end
+
+Create a scope context in which forward pass behavior is set to training mode,
+without changing the recording states.
+
+```julia
+y = model(x)
+train_mode() do
+  z = mx.Dropout(y)
+  ...
+end
+```
+"""
+train_mode(f) = _record(f, nothing, true)
+
+"""
+    predict_mode(f)
+    predict_mode() do
+      ...
+    end
+
+Create a scope context in which forward pass behavior is set to inference mode,
+without changing the recording states.
+
+```julia
+record() do
+  y = model(x)
+  predict_mode() do
+    y = sampling(y)
+  end
+end
+```
+"""
+predict_mode(f) = _record(f, nothing, false)
+
+"""
+    backward!(head,  head_grad;  retain_graph = false, train_mode = true)
+    backward!(heads, head_grads; retain_graph = false, train_mode = true)
+
+Compute the gradients of heads w.r.t. previously marked variables.
+
+## Parameters
+
+- `head::NDArray`: output NDArray
+
+- `head_grad::NDArray` or `Void`: gradient coefficient with respect to head.
+
+- `heads::Vector{NDArray}`: a list of output NDArray
+
+- `head_grads::Vector`: a list of gradient coefficients with respect to heads.
+  Each element should be an `NDArray` or `Void`.
+
+- `retain_graph::Bool`: whether to keep the graph after backward, e.g.
+  if you want to differentiate the same graph twice,
+  you need to pass `retain_graph = true`.
+
+- `train_mode::Bool`: whether to do backward for training or predicting.
+"""
+backward!(head::NDArray, head_grad::NDArray; kws...) =
+  backward!([head], [head_grad]; kws...)
+
+backward!(head::NDArray, head_grad::Void = nothing; kws...) =
+  backward!([head], head_grad; kws...)
+
+function backward!(heads::VecOfNDArray, head_grad::Void;
+                   retain_graph::Bool = false, train_mode::Bool = true)
+  @mxcall(
+    :MXAutogradBackwardEx,
+    (MX_uint,
+     Ptr{MX_handle},
+     Ptr{MX_handle},
+     MX_uint,
+     Ptr{MX_handle},
+     Cint,
+     Cint,
+     Cint,
+     Ptr{MX_handle},
+     Ptr{MX_handle}),
+    length(heads),
+    map(x -> x.handle, heads),
+    C_NULL,
+    0,
+    C_NULL,
+    retain_graph,
+    false,  # create_graph
+    train_mode,
+    C_NULL,
+    C_NULL)
+end
+
+function backward!(heads::VecOfNDArray, head_grads::Vector;
+                   retain_graph::Bool = false, train_mode::Bool = true)
+  output_handles = map(x -> x.handle, heads)
+  ograd_handles = map(head_grads) do x
+    if x isa NDArray
+      x.handle
+    elseif x isa Void
+      MX_handle(C_NULL)
+    else
+      throw(ArgumentError("element of head_grads should be NDArray or Void"))
+    end
+  end
+  @assert length(output_handles) == length(ograd_handles)
+  @mxcall(
+    :MXAutogradBackwardEx,
+    (MX_uint,
+     Ptr{MX_handle},
+     Ptr{MX_handle},
+     MX_uint,
+     Ptr{MX_handle},
+     Cint,
+     Cint,
+     Cint,
+     Ptr{MX_handle},
+     Ptr{MX_handle}),
+    length(output_handles),
+    output_handles,
+    ograd_handles,
+    0,
+    C_NULL,
+    retain_graph,
+    false,  # create_graph
+    train_mode,
+    C_NULL,
+    C_NULL)
+end
+
+"""
+    getgrad(arr::NDArray)
+
+Returns the gradient buffer attached to this `NDArray`.
+If the gradient buffer isn't attached yet, return `nothing`.
+"""
+function getgrad(arr::NDArray)
+  out = Ref{MX_handle}(C_NULL)
+  @mxcall(:MXNDArrayGetGrad, (MX_handle, Ref{MX_handle}), arr.handle, out)
+  (out[] == C_NULL) ? nothing : NDArray(MX_NDArrayHandle(out[]))
+end
+
+"""
+    attach_grad!(x::NDArray, grad_req::Symbol = :write)
+
+Attach a gradient buffer to this `NDArray`,
+so that [`backward!`](@ref) can compute gradient with respect to it.
+
+## Parameters
+
+- `x::NDArray`
+- `grad_req::Symbol` (default is `:write`)
+
+## Return
+
+The attached gradient buffer.
+
+## See also
+
+- [`getgrad`](@ref)
+"""
+function attach_grad!(x::NDArray, grad_req::Symbol = :write)
+  # TODO: support storage type (stype in Python)
+  # TODO: make sure it works with gpu array
+  grad = zeros_like(x)
+  _mark_variables!([x], [grad], grad_req)
+  grad
+end
+
+"""
+    mark_variables!(var,  grad,  grad_req)
+    mark_variables!(vars, grads, grad_reqs)
+
+Mark `NDArrays` as variables to compute gradient for autograd.
+
+## Parameters
+
+- `var::NDArray`
+- `grad::NDArray`
+- `grad_req::Symbol`: `:nop`, `:write`, `:inplace` or `:add`
+- `vars::Vector{NDArray}`
+- `grads::Vector{NDArray}`
+- `grad_reqs::Vector{Symbol}`
+"""
+mark_variables!(var::NDArray, grad::NDArray, grad_reqs::Symbol = :write) =
+  _mark_variables!([var], [grad], grad_reqs)
+
+mark_variables!(var::VecOfNDArray, grads::VecOfNDArray, grad_reqs = :write) =
+  _mark_variables!(var, grads, grad_reqs)
+
+@inline function _getgrad_req(x::Symbol)::GRAD_REQ
+  val = get(grad_req_map, x, false)
+  if val == false
+    throw(ArgumentError("invalid grad_reqs $x"))
+  end
+  val
+end
+
+@inline _getgrad_reqs(x::Symbol, n::Int) =
+  map((_) -> MX_uint(_getgrad_req(x)), Base.OneTo(n))
+
+@inline function _getgrad_reqs(xs::Vector{Symbol}, n::Int)
+  if length(xs) != n
+    throw(ArgumentError("number of variables and grad_reqs not matched"))
+  end
+  map(MX_uint ∘ _getgrad_req, xs)
+end
+
+@inline function _mark_variables!(vars::VecOfNDArray, grads::VecOfNDArray,
+                                  grad_reqs = :write)
+  n = length(vars)
+  if n != length(grads)
+    throw(ArgumentError("number of variables and gradients not matched"))
+  end
+
+  var_hdls  = map(x -> x.handle, vars)
+  grad_hdls = map(x -> x.handle, grads)
+  grad_reqs = _getgrad_reqs(grad_reqs, n)
+
+  @mxcall(:MXAutogradMarkVariables,
+          (MX_uint, Ref{MX_handle}, Ptr{MX_uint}, Ref{MX_handle}),
+          length(vars), var_hdls, grad_reqs, grad_hdls)
+end
+
+"""
+    symbol(x::NDArray)
+
+Retrieve recorded computation history as `SymbolicNode`,
+where `x` is an `NDArray` representing the head of the computation graph.
+"""
+function symbol(x::NDArray)
+  ref = Ref{MX_handle}(C_NULL)
+  @mxcall(:MXAutogradGetSymbol, (MX_handle, Ref{MX_handle}), x, ref)
+  SymbolicNode(MX_SymbolHandle(ref[]))
+end
+
+###############################################################################
+# TODO: User-defined differentiable function
+###############################################################################
diff --git a/src/base.jl b/src/base.jl
index 8f14d44c6..b8f73eb4e 100644
--- a/src/base.jl
+++ b/src/base.jl
@@ -20,7 +20,7 @@ const char_pp = Ptr{char_p}
 ################################################################################
 # OpReqType in include/mxnet/op_attr_types.h
 @enum GRAD_REQ GRAD_NOP=0 GRAD_WRITE=1 GRAD_INPLACE=2 GRAD_ADD=3
-const grad_req_map = Dict{Symbol, GRAD_REQ}(
+const grad_req_map = Dict{Symbol,GRAD_REQ}(
   :nop     => GRAD_NOP,      # no operation, do not write anything
   :write   => GRAD_WRITE,    # write gradient to provided space
   :inplace => GRAD_INPLACE,  # perform an inplace write
diff --git a/src/ndarray.jl b/src/ndarray.jl
index 139e40ef1..de5d6ba4f 100644
--- a/src/ndarray.jl
+++ b/src/ndarray.jl
@@ -306,10 +306,10 @@ dimension. For example, given an `NDArray` of shape (2,3,4), `slice(array, 2:3)`
 a `NDArray` of shape (2,3,2), sharing the data with the original array. This
 operation is used in data parallelization to split mini-batch into sub-batches
 for different devices.
 """
-function slice(arr :: NDArray, ::Colon)
+function slice(arr::NDArray, ::Colon)
   arr
 end
-function slice(arr :: NDArray, slice::UnitRange{Int})
+function slice(arr::NDArray, slice::UnitRange{Int})
   dim1 = size(arr)[end]
   @assert(1 <= slice.start <= slice.stop <= dim1)
   if slice.start == 1 && slice.stop == dim1
diff --git a/test/unittest/autograd.jl b/test/unittest/autograd.jl
new file mode 100644
index 000000000..12c1022bd
--- /dev/null
+++ b/test/unittest/autograd.jl
@@ -0,0 +1,386 @@
+module TestAutoGrad
+
+using Base.Test
+
+using MXNet
+
+
+function checkgradient(f, x, y, ∇)
+  ∇x = mx.attach_grad!(x)
+  y′ = mx.record(f)
+  @test copy(y′) ≈ y
+  @test copy(∇x) |> sum == 0
+  mx.backward!(y′)
+  @test copy(mx.getgrad(x)) ≈ ∇
+end # function checkgradient
+
+
+function test_getgrad()
+  info("AutoGrad::getgrad")
+
+  info("AutoGrad::getgrad::unattached")
+  @test nothing == mx.getgrad(mx.zeros(10))
+
+  info("AutoGrad::getgrad::attached")
+  x = mx.NDArray([1 2; 3 4])
+  grad = mx.attach_grad!(x)
+  @test eltype(grad) ≡ Int
+  @test copy(grad) == [0 0; 0 0]
+
+  grad[:] = 42
+  @test copy(mx.getgrad(x)) == [42 42; 42 42]
+end
+
+
+function test_mark_variables!()
+  info("AutoGrad::mark_variables!")
+  x = mx.zeros(4)
+  ẋ = mx.zeros(4)
+  y = mx.zeros(4)
+  ẏ = mx.zeros(4)
+  mx.mark_variables!([x, y], [ẋ, ẏ], [:nop, :nop])
+  ẋ[:] = 42
+  ẏ[:] = 24
+
+  @test copy(mx.getgrad(x)) == [42, 42, 42, 42]
+  @test copy(mx.getgrad(y)) == [24, 24, 24, 24]
+
+  info("AutoGrad::mark_variables!::invalid grad_reqs")
+  x = mx.zeros(4)
+  y = mx.zeros(4)
+  @test_throws ArgumentError mx.mark_variables!(x, y, :magic)
+  @test_throws ArgumentError mx.mark_variables!([x], [y], [:magic])
+
+  info("AutoGrad::mark_variables!::args length mismatch")
+  x = mx.zeros(4)
+  y = mx.zeros(4)
+  z = mx.zeros(4)
+  @test_throws ArgumentError mx.mark_variables!([x], [y, z])
+  @test_throws ArgumentError mx.mark_variables!([x], [y], [:write, :nop])
+end
+
+
+function test_record()
+  let x = mx.NDArray([1 2; 3 4])
+    info("AutoGrad::record::backward!")
+
+    y = [1 4; 9 16]
+    ∇ = [2 4; 6 8]  # gradient is 2x
+    checkgradient(x, y, ∇) do
+      mx.square(x)
+    end
+  end
+
+  let x = mx.NDArray([1 2; 3 4])
+    info("AutoGrad::record::symbol")
+
+    mx.attach_grad!(x)
+    y = mx.record() do
+      mx.square(x)
+    end
+
+    @test copy(y) == [1 4; 9 16]
+
+    @test isa(mx.symbol(y), mx.SymbolicNode)
+  end
+
+  let x = mx.NDArray([1 2; 3 4])
+    info("AutoGrad::record::backward!(retain_graph=true)")
+
+    mx.attach_grad!(x)
+    y = mx.record() do
+      mx.square(x)
+    end
+
+    @test copy(y) == [1 4; 9 16]
+
+    mx.backward!(y, retain_graph=true)
+    # gradient is 2x
+    @test copy(mx.getgrad(x)) == [2 4; 6 8]
+
+    @test isa(mx.symbol(y), mx.SymbolicNode)
+  end
+
+  mx._record(nothing, nothing) do  # no error with edge case
+    @test true
+  end
+end # function test_record
+
+
+function test_is_recording()
+  info("AutoGrad::is_recording")
+  mx.record() do
+    @test mx.is_recording()
+  end
+end # function test_is_recording
+
+
+function test_is_training()
+  info("AutoGrad::is_training")
+  mx.record() do
+    @test mx.is_training()
+  end
+
+  mx.record(false) do
+    @test !mx.is_training()
+  end
+end # function test_is_training
+
+
+function test_pause()
+  info("AutoGrad::pause")
+  let x = mx.NDArray([1 2; 3 4])
+    ∇ = mx.attach_grad!(x)
+    y = mx.record() do
+      y = mx.square(x)
+      mx.pause() do
+        z = mx.square(y)
+        @test copy(z) == [1 16; 81 256]
+      end
+      y
+    end
+
+    @test copy(y) == [1 4; 9 16]
+
+    mx.backward!(y)
+    @test copy(∇) == [2 4; 6 8]
+  end
+end # function test_pause
+
+
+function test_train_mode()
+  info("AutoGrad::train_mode")
+  let x = mx.NDArray(Float32[1 2; 3 4])
+    y = mx.train_mode() do
+      mx.Dropout(x, p = 1)
+    end
+
+    @test all(isnan.(copy(y)))
+  end
+end # function test_train_mode
+
+
+function test_predict_mode()
+  info("AutoGrad::predict_mode")
+  let x = mx.NDArray(Float32[1 2; 3 4])
+    y = mx.predict_mode() do
+      mx.Dropout(x, p = 1)
+    end
+
+    @test copy(y) ≈ Float32[1 2; 3 4]
+  end
+end # function test_predict_mode
+
+
+function test_backward!()
+  info("AutoGrad::backward!::with head_grad")
+  let x = mx.NDArray(Float32[1 2; 3 4]), A = Float32[.2 .4; 0 .1]
+    ∇ = mx.attach_grad!(x)
+    y = mx.record() do
+      mx.square(x)
+    end
+    mx.backward!(y, mx.NDArray(A))
+    @test copy(∇) ≈ [2 4; 6 8] .* A
+  end
+
+  info("AutoGrad::backward!::with head_grads")
+  let x = mx.NDArray(Float32[1 2; 3 4])
+    ∇ = mx.attach_grad!(x)
+    mx.record() do
+      x′ = mx.square(x)
+      y = mx.square(x)
+      z = mx.square(x) .+ 42
+      mx.backward!([x′, y, z], [nothing,
+                                mx.NDArray(Float32[.01 .01; 1 1]),
+                                mx.NDArray(Float32[1 1; .1 .1])])
+    end
+    ans = [4.02  8.04
+           12.6  16.8]
+    @test copy(∇) ≈ ans
+  end
+
+  info("AutoGrad::backward!::ArgumentError")
+  let x = mx.NDArray([42])
+    @test_throws ArgumentError mx.backward!([x], [24])
+  end
+end # function test_backward!
+
+
+function test_symbol()
+  info("AutoGrad::symbol")
+
+  let x = mx.zeros(4)
+    mx.attach_grad!(x)
+    @test isa(mx.symbol(x), mx.SymbolicNode)
+  end
+end
+
+
+function test_add()
+  info("AutoGrad::add")
+
+  info("AutoGrad::add::x")
+  let x = mx.NDArray([1 2; 3 4])
+    y = [1 2; 3 4]
+    ∇ = [1 1; 1 1]  # gradient is 1
+    checkgradient(x, y, ∇) do
+      x
+    end
+  end
+
+  info("AutoGrad::add::+x")
+  let x = mx.NDArray([1 2; 3 4])
+    y = [1 2; 3 4]
+    ∇ = [1 1; 1 1]  # gradient is 1
+    checkgradient(x, y, ∇) do
+      +x
+    end
+  end
+
+  info("AutoGrad::add::x .+ 42")
+  let x = mx.NDArray([1 2; 3 4])
+    y = [43 44; 45 46]
+    ∇ = [1 1; 1 1]  # gradient is 1
+    checkgradient(x, y, ∇) do
+      x .+ 42
+    end
+  end
+
+  info("AutoGrad::add::42 .+ x")
+  let x = mx.NDArray([1 2; 3 4])
+    y = [43 44; 45 46]
+    ∇ = [1 1; 1 1]
+    checkgradient(x, y, ∇) do
+      42 .+ x
+    end
+  end
+
+  # TODO: info("AutoGrad::add::x .+ y")
+end # function test_add
+
+
+function test_sub()
+  info("AutoGrad::sub")
+
+  info("AutoGrad::sub::-x")
+  let x = mx.NDArray([1 2; 3 4])
+    y = [-1 -2; -3 -4]
+    ∇ = [-1 -1; -1 -1]  # gradient is -1
+    checkgradient(x, y, ∇) do
+      -x
+    end
+  end
+
+  info("AutoGrad::sub::x .- 42")
+  let x = mx.NDArray([1 2; 3 4])
+    y = [-41 -40; -39 -38]
+    ∇ = [1 1; 1 1]
+    checkgradient(x, y, ∇) do
+      x .- 42
+    end
+  end
+
+  info("AutoGrad::sub::42 .- x")
+  let x = mx.NDArray([1 2; 3 4])
+    y = [41 40; 39 38]
+    ∇ = -[1 1; 1 1]
+    checkgradient(x, y, ∇) do
+      42 .- x
+    end
+  end
+
+  # TODO: info("AutoGrad::sub::x .- y")
+end # function test_sub
+
+
+function test_mul()
+  info("AutoGrad::mul")
+
+  info("AutoGrad::mul::2x .* x")
+  let x = mx.NDArray([1 2; 3 4])
+    y = [2 8; 18 32]
+    ∇ = [4 8; 12 16]  # 4x
+    checkgradient(x, y, ∇) do
+      2x .* x
+    end
+  end
+
+  info("AutoGrad::mul::x * 2 .* x")
+  let x = mx.NDArray([1 2; 3 4])
+    y = [2 8; 18 32]
+    ∇ = [4 8; 12 16]  # 4x
+    checkgradient(x, y, ∇) do
+      x * 2 .* x
+    end
+  end
+end
+
+
+function test_div()
+  info("AutoGrad::div")
+
+  info("AutoGrad::div::x ./ 2")
+  let x = mx.NDArray(Float32[1 2; 3 4])
+    y = Float32[.5 1; 1.5 2]
+    ∇ = [.5 .5; .5 .5]
+    checkgradient(x, y, ∇) do
+      x ./ 2
+    end
+  end
+
+  info("AutoGrad::rdiv::2 ./ x")
+  let A = Float32[1 2; 3 4], x = mx.NDArray(A)
+    y = 2 ./ A
+    ∇ = @. -2 / A^2  # -2 / x²
+    checkgradient(x, y, ∇) do
+      2 ./ x
+    end
+  end
+end # function test_div
+
+
+function test_power()
+  info("AutoGrad::power")
+
+  info("AutoGrad::power::x.^3")
+  let A = Float32[1 2; 3 4]
+    x = mx.NDArray(A)
+    y = A.^3
+    ∇ = 3(A.^2)
+    checkgradient(x, y, ∇) do
+      x.^3
+    end
+  end
+
+  info("AutoGrad::power::x.^.5")
+  let A = Float32[1 2; 3 4]
+    x = mx.NDArray(A)
+    y = A.^.5
+    ∇ = .5(A.^-.5)
+    checkgradient(x, y, ∇) do
+      x.^.5
+    end
+  end
+end
+
+
+@testset "AutoGrad Test" begin
+  test_getgrad()
+  test_mark_variables!()
+  test_record()
+  test_is_recording()
+  test_is_training()
+  test_pause()
+  test_train_mode()
+  test_predict_mode()
+  test_backward!()
+  test_symbol()
+  test_add()
+  test_sub()
+  test_mul()
+  test_div()
+  test_power()
+end
+
+
+end # module TestAutoGrad
diff --git a/test/unittest/ndarray.jl b/test/unittest/ndarray.jl
index a24126cf1..ef4fb1f23 100644
--- a/test/unittest/ndarray.jl
+++ b/test/unittest/ndarray.jl
@@ -502,7 +502,7 @@ function test_rdiv()
     @test copy(x) ≈ y
   end
 
-  info("NDArray:rdiv::type convert")
+  info("NDArray::rdiv::type convert")
   let x = mx.NDArray([1, 2, 3])
     y = 5.5 ./ x
     @test eltype(y) == Int # this differs from julia
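A slightly longer usage sketch combining the pieces added by this patch (`attach_grad!`, `record`, `pause`, `backward!` with an explicit head gradient, and `getgrad`). The input values, the `Float32` element type, and the outputs shown in comments are illustrative only, not taken from the patch; the expected gradient follows from the chain rule, grad = A .* 2x for y = x² weighted by head gradient A:

```julia
using MXNet

x  = mx.NDArray(Float32[1 2; 3 4])
∇x = mx.attach_grad!(x)        # zero-initialized gradient buffer

y = mx.record() do             # record the graph for y = x²
  y = mx.square(x)
  mx.pause() do                # not recorded; safe for logging, metrics, etc.
    @assert !mx.is_recording()
  end
  y
end

A = Float32[1 1; .5 .5]        # head gradient: scales each element's gradient
mx.backward!(y, mx.NDArray(A))

copy(mx.getgrad(x))            # same buffer as ∇x
# 2×2 Array{Float32,2}:
#  2.0  4.0
#  3.0  4.0
```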