diff --git a/src/Nn/Nn.jl b/src/Nn/Nn.jl
index 253fa56b..ba3d0338 100644
--- a/src/Nn/Nn.jl
+++ b/src/Nn/Nn.jl
@@ -541,17 +541,20 @@ abstract type OptimisationAlgorithm end
 include("Nn_default_optalgs.jl")
 
 """
-   fitting_info(nn,x,y;n,batch_size,epochs,verbosity,n_epoch,n_batch)
+   fitting_info(nn,xbatch,ybatch,x,y;n,n_batches,epochs,epochs_ran,verbosity,n_epoch,n_batch)
 
 Default callback function to display information during training, depending on the verbosity level
 
 # Parameters:
 * `nn`: Worker network
-* `x`: Batch input to the network (batch_size,d)
-* `y`: Batch label input (batch_size,d)
+* `xbatch`: Batch input to the network (batch_size,din)
+* `ybatch`: Batch label input (batch_size,dout)
+* `x`: Full input to the network (n_records,din)
+* `y`: Full label input (n_records,dout)
 * `n`: Size of the full training set
 * `n_batches`: Number of batches per epoch
 * `epochs`: Number of epochs defined for the training
+* `epochs_ran`: Number of epochs already run in previous training sessions
 * `verbosity`: Verbosity level defined for the training (NONE,LOW,STD,HIGH,FULL)
 * `n_epoch`: Counter of the current epoch
 * `n_batch`: Counter of the current batch
@@ -559,19 +562,18 @@ Default callback function to display information during training, depending on the verbosity level
 
 # Notes:
 * Reporting of the error (loss of the network) is expensive. Use `verbosity=NONE` for better performance
 """
-function fitting_info(nn,x,y;n,n_batches,epochs,verbosity,n_epoch,n_batch)
+function fitting_info(nn,xbatch,ybatch,x,y;n,n_batches,epochs,epochs_ran,verbosity,n_epoch,n_batch)
     if verbosity == NONE
         return false # doesn't stop the training
     end
 
     nMsgDict = Dict(LOW => 0, STD => 10, HIGH => 100, FULL => n)
     nMsgs = nMsgDict[verbosity]
-    batch_size = size(x,1)
 
     if verbosity == FULL || ( n_batch == n_batches && ( n_epoch == 1 || n_epoch % ceil(epochs/nMsgs) == 0))
         ϵ = loss(nn,x,y)
-        println("Training.. \t avg ϵ on (Epoch $n_epoch Batch $n_batch): \t $(ϵ)")
+        println("Training.. \t avg loss on epoch $n_epoch ($(n_epoch+epochs_ran)): \t $(ϵ)")
     end
     return false
 end
@@ -614,7 +616,7 @@ Low level function that trains a neural network with the given x,y data.
 - The verbosity can be set to any of `NONE`,`LOW`,`STD`,`HIGH`,`FULL`.
 - The update is done computing the average gradient for each batch and then calling `single_update!` to let the optimisation algorithm perform the parameters update
 """
-function train!(nn::NN,x,y; epochs=100, batch_size=min(size(x,1),32), sequential=false, verbosity::Verbosity=STD, cb=fitting_info, opt_alg::OptimisationAlgorithm=ADAM(), rng = Random.GLOBAL_RNG)#, η=t -> 1/(1+t), λ=1, rShuffle=true, nMsgs=10, tol=0opt_alg::SD=SD())
+function train!(nn::NN,x,y; epochs=100, batch_size=min(size(x,1),32), sequential=false, nepochs_ran=0, verbosity::Verbosity=STD, cb=fitting_info, opt_alg::OptimisationAlgorithm=ADAM(), rng = Random.GLOBAL_RNG)#, η=t -> 1/(1+t), λ=1, rShuffle=true, nMsgs=10, tol=0opt_alg::SD=SD())
     if verbosity > STD
         @codelocation
     end
@@ -671,9 +673,9 @@ function train!(nn::NN,x,y; epochs=100, batch_size=min(size(x,1),32), sequential
             #println("****foooo")
             #println(▽)
-            res = single_update!(θ,▽;n_epoch=t,n_batch=i,n_batches=n_batches,xbatch=xbatch,ybatch=ybatch,opt_alg=opt_alg)
+            res = single_update!(θ,▽;n_epoch=t+nepochs_ran,n_batch=i,n_batches=n_batches,xbatch=xbatch,ybatch=ybatch,opt_alg=opt_alg)
             set_params!(nn,res.θ)
-            cbOut = cb(nn,xbatch,ybatch,n=d,n_batches=n_batches,epochs=epochs,verbosity=verbosity,n_epoch=t,n_batch=i)
+            cbOut = cb(nn,xbatch,ybatch,x,y,n=d,n_batches=n_batches,epochs=epochs,epochs_ran=nepochs_ran,verbosity=verbosity,n_epoch=t,n_batch=i)
             if(res.stop==true || cbOut==true)
                 nn.trained = true
                 return (epochs=t,ϵ_epochs=ϵ_epochs,θ_epochs=θ_epochs)
             end
@@ -1062,7 +1064,7 @@ function fit!(m::NeuralNetworkEstimator,X,Y)
 
     nnstruct = m.par.nnstruct
 
-    out = train!(nnstruct,X,Y; epochs=epochs, batch_size=batch_size, sequential=!shuffle, verbosity=verbosity, cb=cb, opt_alg=opt_alg, rng = rng)
+    out = train!(nnstruct,X,Y; epochs=epochs, batch_size=batch_size, sequential=!shuffle, verbosity=verbosity, cb=cb, opt_alg=opt_alg, nepochs_ran=m.info["nepochs_ran"], rng=rng)
 
     m.info["nepochs_ran"] += out.epochs
     append!(m.info["loss_per_epoch"],out.ϵ_epochs)
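Reviewer note: this PR changes the callback contract of `train!`. A custom `cb` must now accept the batch *and* the full data as positional arguments (`nn, xbatch, ybatch, x, y`) plus the new `epochs_ran` keyword, and training stops when the callback returns `true` (the default `fitting_info` always returns `false`). Below is a minimal sketch of a conforming early-stopping callback; the `make_stop_at_loss` helper and its threshold are illustrative only (not part of this PR), and it assumes the module's `loss(nn,x,y)` is in scope:

```julia
# Hypothetical user-defined callback matching the new signature that train! invokes:
#   cb(nn, xbatch, ybatch, x, y; n, n_batches, epochs, epochs_ran, verbosity, n_epoch, n_batch)
function make_stop_at_loss(target)
    return function (nn, xbatch, ybatch, x, y; n, n_batches, epochs, epochs_ran,
                     verbosity, n_epoch, n_batch)
        n_batch == n_batches || return false   # evaluate only at the end of each epoch
        ϵ = loss(nn, x, y)                     # loss on the full data, as fitting_info now does
        println("Epoch $(n_epoch + epochs_ran): avg loss = $ϵ")
        return ϵ <= target                     # returning true stops training early
    end
end

# e.g. train!(nn, x, y; epochs=100, cb=make_stop_at_loss(0.05))
```

With the `fit!` change, repeated `fit!` calls on the same `NeuralNetworkEstimator` continue the epoch numbering across sessions: `m.info["nepochs_ran"]` is threaded back into `train!` as `nepochs_ran`, so both the optimiser's `n_epoch` and the reported epoch count reflect previous training.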