Skip to content

Commit

Permalink
issues seem to be largely related to logging
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Jul 1, 2024
1 parent 8389660 commit e61af35
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/mlj_flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ function MLJFlux.fitresult(model::LaplaceRegression, chain, y)
else
target_column_names = Tables.schema(y).names
end
@info "From fitresult"
println(chain)
return (chain, deepcopy(model))
end

Expand Down Expand Up @@ -254,6 +256,12 @@ function MLJFlux.train(
)
verbosity != 1 || next!(meter)

# initiate history:
loss = model.loss
n_batches = length(y)
losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches)
history = [mean(losses)]

for i in 1:epochs
chain, optimiser_state, current_loss = MLJFlux.train_epoch(
model, chain, regularized_optimiser, optimiser_state, X, y
Expand Down Expand Up @@ -281,6 +289,8 @@ function MLJFlux.train(
move,
)

@info "From train"
println(chain)
fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y)

report = history
Expand Down Expand Up @@ -499,8 +509,6 @@ function MLJFlux.train(
)
verbose_laplace = false

# Initialize history:
history = []
# intitialize and start progress meter:
meter = Progress(
model.epochs + 1;
Expand All @@ -510,6 +518,13 @@ function MLJFlux.train(
barlen=25,
color=:yellow,
)

# initiate history:
loss = model.loss
n_batches = length(y)
losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches)
history = [mean(losses)]

for i in 1:epochs
chain, optimiser_state, current_loss = MLJFlux.train_epoch(
model, chain, regularized_optimiser, optimiser_state, X, y
Expand Down

0 comments on commit e61af35

Please sign in to comment.