diff --git a/src/autograd.jl b/src/autograd.jl index 6a98b2f9a..e7ce335e2 100644 --- a/src/autograd.jl +++ b/src/autograd.jl @@ -48,36 +48,32 @@ end _set_training(::Void) = nothing -""" -Get status on recording/not recording. +############################################################################### +# Public API +############################################################################### -## Returns +""" + is_recording()::Bool -Current state of recording. +Get status on recording/not recording. """ -function _is_recording()::Bool +function is_recording()::Bool state = Ref{Cint}(C_NULL) @mxcall(:MXAutogradIsRecording, (Ref{Cint},), state) state[] end """ -Get status on recording/not recording. - -## Returns + is_training()::Bool -Current state of recording. +Get status on recording/not recording. """ -function _is_training()::Bool +function is_training()::Bool state = Ref{Cint}(C_NULL) @mxcall(:MXAutogradIsTraining, (Ref{Cint},), state) state[] end -############################################################################### -# Public API -############################################################################### - @inline function _record(f, is_record::Union{Void,Bool}, train_mode::Union{Void,Bool}) # Port from Python's `_RecordingStateScope` context manager # __enter__ @@ -247,9 +243,9 @@ function backward!(heads::VecOfNDArray, head_grads::Vector; retain_graph::Bool = false, train_mode::Bool = true) output_handles = map(x -> x.handle, heads) ograd_handles = map(head_grads) do x - if isa(x, NDArray) + if x isa NDArray arr.handle - elseif isa(x, Void) + elseif x isa Void MX_handle(C_NULL) else throw(ArgumentError("element of head_grads should be NDArray or Void")) diff --git a/test/unittest/autograd.jl b/test/unittest/autograd.jl index 3a484ae1f..d56730475 100644 --- a/test/unittest/autograd.jl +++ b/test/unittest/autograd.jl @@ -102,6 +102,26 @@ function test_record() end # function test_record +function test_is_recording() + info("AutoGrad::is_recording") + mx.record() do + @test is_recording() + end +end # function test_is_recording + + +function test_is_training() + info("AutoGrad::is_training") + mx.record() do + @test is_training() + end + + mx.record(false) do + @test !is_training() + end +end # function test_is_training + + function test_pause() info("AutoGrad::pause") let x = mx.NDArray([1 2; 3 4]) @@ -284,6 +304,8 @@ end # function test_div test_getgrad() test_mark_variables!() test_record() + test_is_recording() + test_is_training() test_pause() test_train_mode() test_predict_mode()