autograd: initial port of Python's autograd
Ref: https://github.com/apache/incubator-mxnet/blob/065adb3702c110af7b537799be3ec9c16c27a72b/python/mxnet/autograd.py

* API ported
    * attach_grad
    * grad
    * mark_variables
    * get_symbol
    * record
    * pause
    * train_mode
    * predict_mode
    * backward

* An example

```julia
x = mx.NDArray([1 2; 3 4])
mx.attach_grad(x)

y = mx.record() do
  mx.square(x)
end

mx.backward(y)

copy(mx.grad(x))
 # 2×2 Array{Int64,2}:
 # 2  4
 # 6  8
```
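
* Another usage sketch (assuming the same ops as above): `pause` skips gradient
  recording inside a `record` block, e.g. for logging or metric updates

```julia
x = mx.NDArray([1 2; 3 4])
mx.attach_grad(x)

y = mx.record() do
  z = mx.square(x)
  mx.pause() do
    # anything here runs without being recorded
    println(size(z))
  end
  z
end

mx.backward(y)

copy(mx.grad(x))
 # 2×2 Array{Int64,2}:
 # 2  4
 # 6  8
```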
iblislin committed Sep 25, 2017
1 parent df13ddd commit 55edb4f
Showing 3 changed files with 382 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/MXNet.jl
@@ -36,6 +36,8 @@ include("name.jl")
include("symbolic-node.jl")
include("executor.jl")

include("autograd.jl")

include("metric.jl")
include("optimizer.jl")
include("initializer.jl")
351 changes: 351 additions & 0 deletions src/autograd.jl
@@ -0,0 +1,351 @@
# Autograd for NDArray
# this is a port of Python's autograd module
# https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py

###############################################################################
# Private util functions
###############################################################################

"""
_set_recording(state::Bool)::Bool
Set status to recording/not recording. When recording, graph will be constructed
for gradient computation.
## Parameters
* `state::Bool`
## Returns
* previous state before this set.
"""
function _set_recording(state::Bool)::Bool
prev = Ref{Cint}(C_NULL)
@mxcall(:MXAutogradSetIsRecording, (Cint, Ref{Cint}), state, prev)
prev[]
end

"""
Set status to training/predicting.
For example, Dropout will drop inputs randomly when
`train_mode=true` while simply passing through if `train_mode=false`.
## Parameters
* `train_mode::Bool`
## Returns
Previous state before this set.
"""
function _set_training(train_mode::Bool)::Bool
prev = Ref{Cint}(C_NULL)
@mxcall(:MXAutogradSetIsTraining, (Cint, Ref{Cint}), train_mode, prev)
prev[]
end

"""
Get status on recording/not recording.
## Returns
Current state of recording.
"""
function _is_recording()::Bool
state = Ref{Cint}(C_NULL)
@mxcall(:MXAutogradIsRecording, (Ref{Cint},), state)
state[]
end

"""
Get status on recording/not recording.
## Returns
Current state of recording.
"""
function _is_training()::Bool
state = Ref{Cint}(C_NULL)
@mxcall(:MXAutogradIsTraining, (Ref{Cint},), state)
state[]
end

###############################################################################
# Public API
###############################################################################

@inline function _record(f::Function, is_record::Union{Void, Bool},
                         train_mode::Union{Void, Bool})
  # Port of Python's `_RecordingStateScope` context manager
  # __enter__
  prev_is_record = (is_record == nothing) ? nothing : _set_recording(is_record)
  prev_train_mode = (train_mode == nothing) ? nothing : _set_training(train_mode)

  try
    f()
  finally
    # __exit__
    if is_record != nothing && prev_is_record != is_record
      _set_recording(prev_is_record)
    end
    if train_mode != nothing && prev_train_mode != train_mode
      _set_training(prev_train_mode)
    end
  end
end

"""
record(f::Function)
record() do
...
end
Returns an autograd recording scope context to be used in `do` block
and captures code that needs gradients to be calculated.
.. note:: When forwarding with `train_mode=false`, the corresponding backward
should also use `train_mode=false`, otherwise gradient is undefined.
## Example
```julia
# TBD
```
## Parameters
* `train_mode::Bool` (default is `true`)
Whether the forward pass is in training or predicting mode.
This controls the behavior of some layers such as `Dropout`, `BatchNorm`.
"""
record(f::Function, train_mode::Bool=true) = _record(f, true, train_mode)

"""
pause(f::Function)
pause() do
...
end
Returns a scope context to be used in 'with' statement for codes that do not
need gradients to be calculated.
## Example (TBD)
```julia
record() do
...
pause() do
# testing, IO, gradient updates...
end
end
```
## Parameters
* `train_mode::Bool` (default is `false`)
Whether to do forward for training or predicting.
"""
pause(f::Function, train_mode::Bool=false) = _record(f, false, train_mode)

"""
train_mode(f::Function)
train_mode() do
...
end
Returns a scope context to be used in 'with' statement in which forward pass
behavior is set to training mode, without changing the recording states.
## Example
```julia
y = model(x)
train_mode() do
y = dropout(y)
end
```
"""
train_mode(f::Function) = _record(f, nothing, true)

"""
predict_mode(f::Function)
predict_mode() do
...
end
Returns a scope context to be used in 'with' statement in which forward pass
behavior is set to inference mode, without changing the recording states.
## Example
```julia
record() do
y = model(x)
predict_mode() do
y = sampling(y)
end
end
```
"""
predict_mode(f::Function) = _record(f, nothing, false)

backward(head::NDArray, head_grad::NDArray; kwargs...) =
  backward([head], [head_grad]; kwargs...)

backward(head::NDArray, head_grad::Void=nothing; kwargs...) =
  backward([head], head_grad; kwargs...)

function backward(heads::Vector{NDArray}, head_grads::Union{Vector, Void}=nothing;
                  retain_graph::Bool=false, train_mode::Bool=true)
  output_handles = map(arr -> arr.handle, heads)

  if head_grads == nothing
    @mxcall(:MXAutogradBackwardEx,
            (MX_uint, Ptr{MX_handle}, Ptr{MX_handle}, Cint, Cint),
            length(output_handles), output_handles, C_NULL,
            retain_graph, train_mode)
    return
  end

  ograd_handles = map(head_grads) do arr
    if isa(arr, NDArray)
      arr.handle
    elseif isa(arr, Void)
      MX_handle(C_NULL)
    else
      throw(ArgumentError("element of head_grads should be NDArray or Void"))
    end
  end
  @assert length(output_handles) == length(ograd_handles)
  @mxcall(:MXAutogradBackwardEx,
          (MX_uint, Ptr{MX_handle}, Ptr{MX_handle}, Cint, Cint),
          length(output_handles), output_handles, ograd_handles,
          retain_graph, train_mode)
end

"""
backward(head, head_grad; retain_graph=false, train_mode=true)
backward(heads, head_grads; retain_graph=false, train_mode=true)
Compute the gradients of heads w.r.t previously marked variables.
## Parameters
- `head::NDArray`: output NDArray
- `heads::Vector{NDArray}`: a list of output NDArray
- `head_grad::NDArray` or `Void`: gradient with respect to head.
- `head_grads::Vector`: a list of gradient with respect ot heads.
the element should be `NDArray` or `Void`
retain_graph: whether to keep the graph after backward. e.g:
If you want to differentiate the same graph twice, you need to pass retain_graph=true.
train_mode: bool, optional
Whether to do backward for training or predicting.
"""
backward
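
# Usage sketch (hedged; follows the example in the commit message): calling
# `backward` with an explicit all-ones head gradient gives the same result as
# passing no head gradient at all.
#
#   x = mx.NDArray([1 2; 3 4])
#   mx.attach_grad(x)
#   y = mx.record() do
#     mx.square(x)
#   end
#   mx.backward([y], [mx.NDArray(ones(Int, 2, 2))])  # all-ones head gradient
#   copy(mx.grad(x))  # => [2 4; 6 8], i.e. 2x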

"""
grad(arr::NDArray)
Returns gradient buffer attached to this `NDArray`.
If the gradient buffer isn't attached yet, return `nothing`.
"""
function grad(arr::NDArray)
out = Ref{mx.MX_handle}(C_NULL)
@mxcall(:MXNDArrayGetGrad, (MX_handle, Ref{MX_handle}), arr.handle, out)
(out[] == C_NULL) ? nothing: NDArray(MX_NDArrayHandle(out[]))
end

"""
attach_grad(arr::NDArray, grad_req::Symbol=:write)
Attach a gradient buffer to this `NDArray`, so that [`backward`](@ref)
can compute gradient with respect to it.
## Parameters
- `arr::NDArray`
- `grad_req::Symbol` (default is `:write`)
## Return
The attached gradient buffer
## See also
- [`grad`](@ref)
"""
function attach_grad(arr::NDArray, grad_req::Symbol=:write)
# TODO: support storage type (stype in Python)
# TODO: make sure it works with gpu array
grad = zeros_like(arr)
_mark_variables([arr], [grad], grad_req)
grad
end

"""
mark_variables(vars, grads)
Mark NDArrays as variables to compute gradient for autograd.
## Parameters
variables: NDArray or list of NDArray
gradients: NDArray or list of NDArray
grad_reqs: str or list of str
"""
mark_variables(var::NDArray, grad::NDArray, grad_reqs::Symbol=:write) =
_mark_variables([var], [grad], grad_reqs)

mark_variables(var::Vector{NDArray}, grads::Vector{NDArray}, grad_reqs=:write) =
_mark_variables(var, grads, grad_reqs)

@inline function _mark_variables(vars::Vector{NDArray}, grads::Vector{NDArray},
                                 grad_reqs::Union{Vector{Symbol}, Symbol}=:write)
  # TODO: leverage grad reqs map from #283
  if length(vars) != length(grads)
    throw(ArgumentError("number of variables and gradients not matched"))
  end

  var_hdls = map(arr -> arr.handle, vars)
  grad_hdls = map(arr -> arr.handle, grads)

  if isa(grad_reqs, Symbol)
    grad_reqs = MX_uint[GRAD_WRITE for i in 1:length(vars)]  # FIXME: honor grad_reqs
  else
    if length(vars) != length(grad_reqs)
      throw(ArgumentError("number of variables and grad_reqs not matched"))
    end
    grad_reqs = MX_uint[GRAD_WRITE for i in 1:length(vars)]  # FIXME: honor grad_reqs
  end

  @mxcall(:MXAutogradMarkVariables,
          (MX_uint, Ref{MX_handle}, Ptr{MX_uint}, Ref{MX_handle}),
          length(vars), var_hdls, grad_reqs, grad_hdls)
end

"""
get_symbol(arr)
Retrieve recorded computation history as `SymbolicNode`.
## Parameters
* `x::NDArray`: Array representing the head of computation graph.
## Returns
The retrieved `Symbol`.
"""
function get_symbol(arr::NDArray)
ref = Ref{MX_handle}(C_NULL)
@mxcall(:MXAutogradGetSymbol, (MX_handle, Ref{MX_handle}), arr, ref)
SymbolicNode(MX_SymbolHandle(ref[]))
end
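
# Usage sketch (hedged; assumes the same ops as the commit-message example):
# pull the recorded history out as a `SymbolicNode`.
#
#   x = mx.NDArray([1 2; 3 4])
#   mx.attach_grad(x)
#   y = mx.record() do
#     mx.square(x)
#   end
#   s = mx.get_symbol(y)  # SymbolicNode describing the recorded square(x) computation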

###############################################################################
# TODO: User-defined differentiable function
###############################################################################
29 changes: 29 additions & 0 deletions test/unittest/autograd.jl
@@ -0,0 +1,29 @@
module TestAutoGrad

using Base.Test

using MXNet


function test_grad()
  info("AutoGrad::grad")

  info("AutoGrad::grad::unattached")
  @test nothing == mx.grad(mx.zeros(10))

  info("AutoGrad::grad::attached")
  x = mx.NDArray([1 2; 3 4])
  grad = mx.attach_grad(x)
  @test eltype(grad) == Int
  @test copy(grad) == [0 0; 0 0]

  grad[:] = 42
  @test copy(mx.grad(x)) == [42 42; 42 42]
end


@testset "AutoGrad Test" begin
  test_grad()
end

end # module TestAutoGrad
