autograd: initial port of Python's autograd
Ref: https://github.com/apache/incubator-mxnet/blob/065adb3702c110af7b537799be3ec9c16c27a72b/python/mxnet/autograd.py

* API ported
  * attach_grad
  * grad
  * mark_variables
  * get_symbol
  * record
  * pause
  * train_mode
  * predict_mode
  * backward
* An example

```julia
x = mx.NDArray([1 2; 3 4])
mx.attach_grad(x)
y = mx.record() do
  mx.square(x)
end
mx.backward(y)

copy(mx.grad(x))
# 2×2 Array{Int64,2}:
#  2  4
#  6  8
```
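A short sketch (not from the commit message; it assumes only the APIs listed above plus `println`/`copy`) of how `pause` composes with `record`: work done inside `pause` is excluded from the recorded graph, so the gradient of `y` is unaffected.

```julia
x = mx.NDArray([1 2; 3 4])
mx.attach_grad(x)
y = mx.record() do
  h = mx.square(x)
  mx.pause() do
    # not recorded into the graph: e.g. logging intermediate values
    println(copy(h))
  end
  h
end
mx.backward(y)
copy(mx.grad(x))  # => [2 4; 6 8]
```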
Showing 3 changed files with 382 additions and 0 deletions.
@@ -0,0 +1,351 @@
# Autograd for NDArray
# this is a port of Python's autograd module
# https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py

###############################################################################
# Private util functions
###############################################################################

""" | ||
_set_recording(state::Bool)::Bool | ||
Set status to recording/not recording. When recording, graph will be constructed | ||
for gradient computation. | ||
## Parameters | ||
* `state::Bool` | ||
## Returns | ||
* previous state before this set. | ||
""" | ||
function _set_recording(state::Bool)::Bool | ||
prev = Ref{Cint}(C_NULL) | ||
@mxcall(:MXAutogradSetIsRecording, (Cint, Ref{Cint}), state, prev) | ||
prev[] | ||
end | ||
|
||
""" | ||
Set status to training/predicting. | ||
For example, Dropout will drop inputs randomly when | ||
`train_mode=true` while simply passing through if `train_mode=false`. | ||
## Parameters | ||
* `train_mode::Bool` | ||
## Returns | ||
Previous state before this set. | ||
""" | ||
function _set_training(train_mode::Bool)::Bool | ||
prev = Ref{Cint}(C_NULL) | ||
@mxcall(:MXAutogradSetIsTraining, (Cint, Ref{Cint}), train_mode, prev) | ||
prev[] | ||
end | ||
|
||
""" | ||
Get status on recording/not recording. | ||
## Returns | ||
Current state of recording. | ||
""" | ||
function _is_recording()::Bool | ||
state = Ref{Cint}(C_NULL) | ||
@mxcall(:MXAutogradIsRecording, (Ref{Cint},), state) | ||
state[] | ||
end | ||
|
||
""" | ||
Get status on recording/not recording. | ||
## Returns | ||
Current state of recording. | ||
""" | ||
function _is_training()::Bool | ||
state = Ref{Cint}(C_NULL) | ||
@mxcall(:MXAutogradIsTraining, (Ref{Cint},), state) | ||
state[] | ||
end | ||
|
||
###############################################################################
# Public API
###############################################################################

@inline function _record(f::Function, is_record::Union{Void, Bool},
                         train_mode::Union{Void, Bool})
  # Port from Python's `_RecordingStateScope` context manager
  # __enter__: only touch the global state for flags that are not `nothing`
  prev_is_record  = (is_record  == nothing) ? nothing : _set_recording(is_record)
  prev_train_mode = (train_mode == nothing) ? nothing : _set_training(train_mode)
  #= println("$is_record $train_mode $prev_is_record $prev_train_mode") =#

  try
    f()
  finally
    # __exit__: restore the previous state
    if is_record != nothing && prev_is_record != is_record
      _set_recording(prev_is_record)
    end
    if train_mode != nothing && prev_train_mode != train_mode
      _set_training(prev_train_mode)
    end
  end
end

""" | ||
record(f::Function) | ||
record() do | ||
... | ||
end | ||
Returns an autograd recording scope context to be used in `do` block | ||
and captures code that needs gradients to be calculated. | ||
.. note:: When forwarding with `train_mode=false`, the corresponding backward | ||
should also use `train_mode=false`, otherwise gradient is undefined. | ||
## Example | ||
```julia | ||
# TBD | ||
``` | ||
## Parameters | ||
* `train_mode::Bool` (default is `true`) | ||
Whether the forward pass is in training or predicting mode. | ||
This controls the behavior of some layers such as `Dropout`, `BatchNorm`. | ||
""" | ||
record(f::Function, train_mode::Bool=true) = _record(f, true, train_mode) | ||
|
||
""" | ||
pause(f::Function) | ||
pause() do | ||
... | ||
end | ||
Returns a scope context to be used in 'with' statement for codes that do not | ||
need gradients to be calculated. | ||
## Example (TBD) | ||
```julia | ||
record() do | ||
... | ||
pause() do | ||
# testing, IO, gradient updates... | ||
end | ||
end | ||
``` | ||
## Parameters | ||
* `train_mode::Bool` (default is `false`) | ||
Whether to do forward for training or predicting. | ||
""" | ||
pause(f::Function, train_mode::Bool=false) = _record(f, false, train_mode) | ||
|
||
""" | ||
train_mode(f::Function) | ||
train_mode() do | ||
... | ||
end | ||
Returns a scope context to be used in 'with' statement in which forward pass | ||
behavior is set to training mode, without changing the recording states. | ||
## Example | ||
```julia | ||
y = model(x) | ||
train_mode() do | ||
y = dropout(y) | ||
end | ||
``` | ||
""" | ||
train_mode(f::Function) = _record(f, nothing, true) | ||
|
||
""" | ||
predict_mode(f::Function) | ||
predict_mode() do | ||
... | ||
end | ||
Returns a scope context to be used in 'with' statement in which forward pass | ||
behavior is set to inference mode, without changing the recording states. | ||
## Example | ||
```julia | ||
record() do | ||
y = model(x) | ||
predict_mode() do | ||
y = sampling(y) | ||
end | ||
end | ||
``` | ||
""" | ||
predict_mode(f::Function) = _record(f, nothing, false) | ||
|
||
backward(head::NDArray, head_grad::NDArray; kwargs...) =
  backward([head], [head_grad]; kwargs...)

backward(head::NDArray, head_grad::Void = nothing; kwargs...) =
  backward([head], head_grad; kwargs...)

function backward(heads::Vector{NDArray}, head_grads::Union{Vector, Void};
                  retain_graph::Bool = false, train_mode::Bool = true)
  output_handles = map(arr -> arr.handle, heads)

  if head_grads == nothing
    @mxcall(:MXAutogradBackwardEx,
            (MX_uint, Ptr{MX_handle}, Ptr{MX_handle}, Cint, Cint),
            length(output_handles), output_handles, C_NULL,
            retain_graph, train_mode)
    return
  end

  ograd_handles = map(head_grads) do arr
    if isa(arr, NDArray)
      arr.handle
    elseif isa(arr, Void)
      MX_handle(C_NULL)
    else
      throw(ArgumentError("element of head_grads should be NDArray or Void"))
    end
  end
  @assert length(output_handles) == length(ograd_handles)
  @mxcall(:MXAutogradBackwardEx,
          (MX_uint, Ptr{MX_handle}, Ptr{MX_handle}, Cint, Cint),
          length(output_handles), output_handles, ograd_handles,
          retain_graph, train_mode)
end

""" | ||
backward(head, head_grad; retain_graph=false, train_mode=true) | ||
backward(heads, head_grads; retain_graph=false, train_mode=true) | ||
Compute the gradients of heads w.r.t previously marked variables. | ||
## Parameters | ||
- `head::NDArray`: output NDArray | ||
- `heads::Vector{NDArray}`: a list of output NDArray | ||
- `head_grad::NDArray` or `Void`: gradient with respect to head. | ||
- `head_grads::Vector`: a list of gradient with respect ot heads. | ||
the element should be `NDArray` or `Void` | ||
retain_graph: whether to keep the graph after backward. e.g: | ||
If you want to differentiate the same graph twice, you need to pass retain_graph=true. | ||
train_mode: bool, optional | ||
Whether to do backward for training or predicting. | ||
""" | ||
backward | ||
|
||
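# A minimal sketch of `backward` with an explicit head gradient (not part of the
# original commit; it only uses the APIs defined in this file plus `mx.NDArray`):
#
#   x = mx.NDArray([1 2; 3 4])
#   mx.attach_grad(x)
#   y = mx.record() do
#     mx.square(x)
#   end
#   # scale the incoming gradient by 2 via an explicit head gradient
#   mx.backward(y, mx.NDArray([2 2; 2 2]))
#   copy(mx.grad(x))  # => [4 8; 12 16]
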
""" | ||
grad(arr::NDArray) | ||
Returns gradient buffer attached to this `NDArray`. | ||
If the gradient buffer isn't attached yet, return `nothing`. | ||
""" | ||
function grad(arr::NDArray) | ||
out = Ref{mx.MX_handle}(C_NULL) | ||
@mxcall(:MXNDArrayGetGrad, (MX_handle, Ref{MX_handle}), arr.handle, out) | ||
(out[] == C_NULL) ? nothing: NDArray(MX_NDArrayHandle(out[])) | ||
end | ||
|
||
""" | ||
attach_grad(arr::NDArray, grad_req::Symbol=:write) | ||
Attach a gradient buffer to this `NDArray`, so that [`backward`](@ref) | ||
can compute gradient with respect to it. | ||
## Parameters | ||
- `arr::NDArray` | ||
- `grad_req::Symbol` (default is `:write`) | ||
## Return | ||
The attached gradient buffer | ||
## See also | ||
- [`grad`](@ref) | ||
""" | ||
function attach_grad(arr::NDArray, grad_req::Symbol=:write) | ||
# TODO: support storage type (stype in Python) | ||
# TODO: make sure it works with gpu array | ||
grad = zeros_like(arr) | ||
_mark_variables([arr], [grad], grad_req) | ||
grad | ||
end | ||
|
||
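# A minimal sketch of `attach_grad` usage (mirrors the commit-message example; the
# returned buffer is the same one `grad(x)` returns):
#
#   x  = mx.NDArray([1 2; 3 4])
#   ∇x = mx.attach_grad(x)   # zero-initialized gradient buffer
#   y  = mx.record() do
#     mx.square(x)
#   end
#   mx.backward(y)
#   copy(∇x)                 # => [2 4; 6 8]
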
""" | ||
mark_variables(vars, grads) | ||
Mark NDArrays as variables to compute gradient for autograd. | ||
## Parameters | ||
variables: NDArray or list of NDArray | ||
gradients: NDArray or list of NDArray | ||
grad_reqs: str or list of str | ||
""" | ||
mark_variables(var::NDArray, grad::NDArray, grad_reqs::Symbol=:write) = | ||
_mark_variables([var], [grad], grad_reqs) | ||
|
||
mark_variables(var::Vector{NDArray}, grads::Vector{NDArray}, grad_reqs=:write) = | ||
_mark_variables(var, grads, grad_reqs) | ||
|
||
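# A minimal sketch of marking variables manually instead of using `attach_grad`
# (assumes `zeros_like`, the same helper `attach_grad` uses, for the buffer):
#
#   x  = mx.NDArray([1 2; 3 4])
#   ∇x = mx.zeros_like(x)
#   mx.mark_variables(x, ∇x)
#   y = mx.record() do
#     mx.square(x)
#   end
#   mx.backward(y)
#   copy(∇x)  # => [2 4; 6 8]
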
@inline function _mark_variables(vars::Vector{NDArray}, grads::Vector{NDArray},
                                 grad_reqs::Union{Vector{Symbol}, Symbol} = :write)
  # TODO: leverage grad reqs map from #283
  if length(vars) != length(grads)
    throw(ArgumentError("number of variables and gradients not matched"))
  end

  var_hdls  = map(arr -> arr.handle, vars)
  grad_hdls = map(arr -> arr.handle, grads)

  if isa(grad_reqs, Symbol)
    grad_reqs = MX_uint[GRAD_WRITE for i ∈ 1:length(vars)]  # FIXME: honor grad_reqs
  else
    if length(vars) != length(grad_reqs)
      throw(ArgumentError("number of variables and grad_reqs not matched"))
    end
    grad_reqs = MX_uint[GRAD_WRITE for i ∈ 1:length(vars)]  # FIXME: honor grad_reqs
  end

  @mxcall(:MXAutogradMarkVariables,
          (MX_uint, Ref{MX_handle}, Ptr{MX_uint}, Ref{MX_handle}),
          length(vars), var_hdls, grad_reqs, grad_hdls)
end

""" | ||
get_symbol(arr) | ||
Retrieve recorded computation history as `SymbolicNode`. | ||
## Parameters | ||
* `x::NDArray`: Array representing the head of computation graph. | ||
## Returns | ||
The retrieved `Symbol`. | ||
""" | ||
function get_symbol(arr::NDArray) | ||
ref = Ref{MX_handle}(C_NULL) | ||
@mxcall(:MXAutogradGetSymbol, (MX_handle, Ref{MX_handle}), arr, ref) | ||
SymbolicNode(MX_SymbolHandle(ref[])) | ||
end | ||
|
||
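# A minimal sketch of `get_symbol`: retrieve the recorded graph of `y` as a
# SymbolicNode (illustrative only, not part of the commit):
#
#   x = mx.NDArray([1 2; 3 4])
#   mx.attach_grad(x)
#   y = mx.record() do
#     mx.square(x)
#   end
#   sym = mx.get_symbol(y)   # SymbolicNode describing the recorded `square` computation
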
###############################################################################
# TODO: User-defined differentiable function
###############################################################################
@@ -0,0 +1,29 @@
module TestAutoGrad

using Base.Test

using MXNet


function test_grad()
  info("AutoGrad::grad")

  info("AutoGrad::grad::unattached")
  @test nothing == mx.grad(mx.zeros(10))

  info("AutoGrad::grad::attached")
  x = mx.NDArray([1 2; 3 4])
  grad = mx.attach_grad(x)
  @test eltype(grad) ≡ Int
  @test copy(grad) == [0 0; 0 0]

  grad[:] = 42
  @test copy(mx.grad(x)) == [42 42; 42 42]
end


@testset "AutoGrad Test" begin
  test_grad()
end

end # module TestAutoGrad