[breaking] NN optimizer update accounts for total number of epochs, more data and better info in training callback

1. Epochs run in previous training sessions are now accounted for in the optimizer update function (e.g. to reduce the step at each epoch), instead of restarting from 1 at each `fit!` (see the sketch after this list).
2. The number of epochs already run and the whole dataset (x,y), and not only (xbatch,ybatch), are now passed to the callback run at each update. The default callback (`fitting_info`) uses this information to report the whole-dataset loss at each epoch.
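
A minimal sketch of why the cumulative counter matters, assuming a decaying step-size rule like the commented-out default `η = t -> 1/(1+t)` visible in `train!` below; the function names here are illustrative, not the package's actual optimiser code:

```julia
# Hypothetical decaying step size: depends only on the epoch counter t.
η(t) = 1 / (1 + t)

# Before this commit: each call to fit!/train! restarted the counter from 1,
# so a resumed training session jumped back to a large step.
step_restarting(n_epoch) = η(n_epoch)

# After this commit: the epochs already run are added to the counter,
# so the decay continues smoothly across training sessions.
step_resuming(n_epoch, nepochs_ran) = η(n_epoch + nepochs_ran)

step_restarting(1)      # 0.5      -> a resumed session would restart with a large step
step_resuming(1, 100)   # ≈ 0.0099 -> a resumed session keeps the decayed step
```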
sylvaticus committed Sep 6, 2023
1 parent 4e19c54 commit 9f076c5
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/Nn/Nn.jl
@@ -541,37 +541,39 @@ abstract type OptimisationAlgorithm end
include("Nn_default_optalgs.jl")

"""
- fitting_info(nn,x,y;n,batch_size,epochs,verbosity,n_epoch,n_batch)
+ fitting_info(nn,xbatch,ybatch,x,y;n,batch_size,epochs,epochs_ran,verbosity,n_epoch,n_batch)
Default callback function to display information during training, depending on the verbosity level
# Parameters:
* `nn`: Worker network
- * `x`: Batch input to the network (batch_size,d)
- * `y`: Batch label input (batch_size,d)
+ * `xbatch`: Batch input to the network (batch_size,din)
+ * `ybatch`: Batch label input (batch_size,dout)
+ * `x`: Full input to the network (n_records,din)
+ * `y`: Full label input (n_records,dout)
* `n`: Size of the full training set
* `n_batches`: Number of batches per epoch
* `epochs`: Number of epochs defined for the training
+ * `epochs_ran`: Number of epochs already run in previous training sessions
* `verbosity`: Verbosity level defined for the training (NONE,LOW,STD,HIGH,FULL)
* `n_epoch`: Counter of the current epoch
* `n_batch`: Counter of the current batch
# Notes:
* Reporting of the error (loss of the network) is expensive. Use `verbosity=NONE` for better performance
"""
- function fitting_info(nn,x,y;n,n_batches,epochs,verbosity,n_epoch,n_batch)
+ function fitting_info(nn,xbatch,ybatch,x,y;n,n_batches,epochs,epochs_ran,verbosity,n_epoch,n_batch)
if verbosity == NONE
return false # doesn't stop the training
end

nMsgDict = Dict(LOW => 0, STD => 10,HIGH => 100, FULL => n)
nMsgs = nMsgDict[verbosity]
- batch_size = size(x,1)

if verbosity == FULL || ( n_batch == n_batches && ( n_epoch == 1 || n_epoch % ceil(epochs/nMsgs) == 0))

ϵ = loss(nn,x,y)
println("Training.. \t avg ϵ on (Epoch $n_epoch Batch $n_batch): \t $(ϵ)")
println("Training.. \t avg loss on epoch $n_epoch ($(n_epoch+epochs_ran)): \t $(ϵ)")
end
return false
end
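
As an illustration of the new callback contract, a user-defined callback now receives both the current batch and the full dataset, plus the `epochs_ran` counter. This is a minimal sketch, not part of the commit; the callback name, printed message and early-stopping threshold are assumptions, while the argument list mirrors the `cb(...)` call in `train!` below:

```julia
# Hypothetical user callback matching the new signature used by train!:
# positional (nn, xbatch, ybatch, x, y) plus the keyword arguments below.
function my_callback(nn, xbatch, ybatch, x, y;
                     n, n_batches, epochs, epochs_ran,
                     verbosity, n_epoch, n_batch)
    if n_batch == n_batches                  # report once per epoch, on the last batch
        ϵ = loss(nn, x, y)                   # loss on the *full* dataset, not just the batch
        println("Epoch $(n_epoch + epochs_ran) (cumulative): whole-dataset loss = $ϵ")
        return ϵ < 1e-6                      # returning true stops the training
    end
    return false                             # returning false lets training continue
end

# Passed to the low-level trainer in place of the default fitting_info:
# train!(nn, x, y; epochs=10, cb=my_callback)
```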
@@ -614,7 +616,7 @@ Low level function that trains a neural network with the given x,y data.
- The verbosity can be set to any of `NONE`,`LOW`,`STD`,`HIGH`,`FULL`.
- The update is done computing the average gradient for each batch and then calling `single_update!` to let the optimisation algorithm perform the parameters update
"""
- function train!(nn::NN,x,y; epochs=100, batch_size=min(size(x,1),32), sequential=false, verbosity::Verbosity=STD, cb=fitting_info, opt_alg::OptimisationAlgorithm=ADAM(),rng = Random.GLOBAL_RNG)#, η=t -> 1/(1+t), λ=1, rShuffle=true, nMsgs=10, tol=0opt_alg::SD=SD())
+ function train!(nn::NN,x,y; epochs=100, batch_size=min(size(x,1),32), sequential=false, nepochs_ran=0,verbosity::Verbosity=STD, cb=fitting_info, opt_alg::OptimisationAlgorithm=ADAM(),rng = Random.GLOBAL_RNG)#, η=t -> 1/(1+t), λ=1, rShuffle=true, nMsgs=10, tol=0opt_alg::SD=SD())
if verbosity > STD
@codelocation
end
@@ -671,9 +673,9 @@ function train!(nn::NN,x,y; epochs=100, batch_size=min(size(x,1),32), sequential
#println("****foooo")
#println(▽)

- res = single_update!(θ,▽;n_epoch=t,n_batch=i,n_batches=n_batches,xbatch=xbatch,ybatch=ybatch,opt_alg=opt_alg)
+ res = single_update!(θ,▽;n_epoch=t+nepochs_ran,n_batch=i,n_batches=n_batches,xbatch=xbatch,ybatch=ybatch,opt_alg=opt_alg)
set_params!(nn,res.θ)
- cbOut = cb(nn,xbatch,ybatch,n=d,n_batches=n_batches,epochs=epochs,verbosity=verbosity,n_epoch=t,n_batch=i)
+ cbOut = cb(nn,xbatch,ybatch,x,y,n=d,n_batches=n_batches,epochs=epochs,epochs_ran=nepochs_ran,verbosity=verbosity,n_epoch=t,n_batch=i)
if(res.stop==true || cbOut==true)
nn.trained = true
return (epochs=t,ϵ_epochs=ϵ_epochs,θ_epochs=θ_epochs)
@@ -1062,7 +1064,7 @@ function fit!(m::NeuralNetworkEstimator,X,Y)
nnstruct = m.par.nnstruct


- out = train!(nnstruct,X,Y; epochs=epochs, batch_size=batch_size, sequential=!shuffle, verbosity=verbosity, cb=cb, opt_alg=opt_alg,rng = rng)
+ out = train!(nnstruct,X,Y; epochs=epochs, batch_size=batch_size, sequential=!shuffle, verbosity=verbosity, cb=cb, opt_alg=opt_alg,nepochs_ran=m.info["nepochs_ran"],rng = rng)

m.info["nepochs_ran"] += out.epochs
append!(m.info["loss_per_epoch"],out.ϵ_epochs)
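At the model-API level this means successive `fit!` calls now resume cleanly. A rough usage sketch follows; the `NeuralNetworkEstimator` construction with the `epochs` keyword and the toy data are assumptions, while `m.info["nepochs_ran"]` is the counter updated in the code above:

```julia
using BetaML                                # assumed package providing NeuralNetworkEstimator and fit!

x = rand(100, 4); y = rand(100, 1)          # toy data, illustrative only

m = NeuralNetworkEstimator(epochs=50)       # hypothetical construction
fit!(m, x, y)                               # first session; assuming no early stop,
                                            # m.info["nepochs_ran"] == 50 afterwards
fit!(m, x, y)                               # resumes: train! receives nepochs_ran=50, so the
                                            # optimiser and callback see cumulative epochs 51-100
m.info["nepochs_ran"]                       # 100
```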
