Skip to content

Commit

Permalink
is_recording/is_training
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin committed Dec 9, 2017
1 parent 853ca31 commit 3c7356d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
28 changes: 12 additions & 16 deletions src/autograd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,36 +48,32 @@ end

_set_training(::Void) = nothing

"""
Get status on recording/not recording.
###############################################################################
# Public API
###############################################################################

## Returns
"""
is_recording()::Bool
Current state of recording.
Get status on recording/not recording.
"""
function _is_recording()::Bool
function is_recording()::Bool
state = Ref{Cint}(C_NULL)
@mxcall(:MXAutogradIsRecording, (Ref{Cint},), state)
state[]
end

"""
Get status on recording/not recording.
## Returns
is_training()::Bool
Current state of recording.
Get status on recording/not recording.
"""
function _is_training()::Bool
function is_training()::Bool
state = Ref{Cint}(C_NULL)
@mxcall(:MXAutogradIsTraining, (Ref{Cint},), state)
state[]
end

###############################################################################
# Public API
###############################################################################

@inline function _record(f, is_record::Union{Void,Bool}, train_mode::Union{Void,Bool})
# Port from Python's `_RecordingStateScope` context manager
# __enter__
Expand Down Expand Up @@ -247,9 +243,9 @@ function backward!(heads::VecOfNDArray, head_grads::Vector;
retain_graph::Bool = false, train_mode::Bool = true)
output_handles = map(x -> x.handle, heads)
ograd_handles = map(head_grads) do x
if isa(x, NDArray)
if x isa NDArray
arr.handle
elseif isa(x, Void)
elseif x isa Void
MX_handle(C_NULL)
else
throw(ArgumentError("element of head_grads should be NDArray or Void"))
Expand Down
22 changes: 22 additions & 0 deletions test/unittest/autograd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ function test_record()
end # function test_record


function test_is_recording()
info("AutoGrad::is_recording")
mx.record() do
@test is_recording()
end
end # function test_is_recording


function test_is_training()
info("AutoGrad::is_training")
mx.record() do
@test is_training()
end

mx.record(false) do
@test !is_training()
end
end # function test_is_training


function test_pause()
info("AutoGrad::pause")
let x = mx.NDArray([1 2; 3 4])
Expand Down Expand Up @@ -284,6 +304,8 @@ end # function test_div
test_getgrad()
test_mark_variables!()
test_record()
test_is_recording()
test_is_training()
test_pause()
test_train_mode()
test_predict_mode()
Expand Down

0 comments on commit 3c7356d

Please sign in to comment.