Skip to content

Commit

Permalink
Moved log message parsing to update! instead of updateone
Browse files Browse the repository at this point in the history
  • Loading branch information
wilan-wong-1 committed Nov 9, 2023
1 parent 13e4b84 commit be3236b
Showing 1 changed file with 25 additions and 35 deletions.
60 changes: 25 additions & 35 deletions src/booster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,8 @@ function updateone!(b::Booster, Xy::DMatrix;
update_feature_names::Bool=false,
)
xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle)
# obtain the logs if watchlist is present (for early stopping and/or info)
if isempty(watchlist)
msg = nothing
else
msg = evaliter(b, watchlist, round_number)
@info msg
end
#isempty(watchlist) || (msg = evaliter(b, watchlist, round_number))
_maybe_update_feature_names!(b, Xy, update_feature_names)
b, msg
b
end

function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::AbstractVector{<:Real};
Expand All @@ -394,16 +386,8 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr
g = convert(Vector{Cfloat}, g)
h = convert(Vector{Cfloat}, h)
xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g))
# obtain the logs if watchlist is present (for early stopping and/or info)
if isempty(watchlist)
msg = nothing
else
msg = evaliter(b, watchlist, round_number)
@info msg
end
#isempty(watchlist) || (msg = evaliter(b, watchlist, round_number))
_maybe_update_feature_names!(b, Xy, update_feature_names)
b, msg
b
end

"""
Expand Down Expand Up @@ -458,20 +442,26 @@ function update!(b::Booster, data, a...;
for j 1:num_round
round_number = getnrounds(b) + 1

b, msg = updateone!(b, data, a...; round_number, watchlist, kw...)
if !isempty(watchlist) && early_stopping_rounds > 0
score, dataset, metric = extract_metric_value(msg)
if (maximize && score > best_score || (!maximize && score < best_score))
best_score = score
best_round = j
elseif j - best_round >= early_stopping_rounds
@info(
"Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds."
)
# add additional fields to record the best iteration
b.best_iteration = best_round
b.best_score = best_score
return b
updateone!(b, data, a...; round_number, watchlist, kw...)

# Evaluate if watchlist is not empty
if !isempty(watchlist)
msg = evaliter(b, watchlist, round_number)
@info msg
if early_stopping_rounds > 0
score, dataset, metric = extract_metric_value(msg)
if (maximize && score > best_score || (!maximize && score < best_score))
best_score = score
best_round = j
elseif j - best_round >= early_stopping_rounds
@info(
"Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds."
)
# add additional fields to record the best iteration
b.best_iteration = best_round
b.best_score = best_score
return b
end
end
end
end
Expand Down Expand Up @@ -585,7 +575,7 @@ ŷ = predict(b, dvalid, ntree_limit = b.best_iteration)
"""
function xgboost(dm::DMatrix, a...;
num_round::Integer=10,
watchlist::Any = Dict("train" => dm),
watchlist::AbstractDict = Dict("train" => dm),
early_stopping_rounds::Integer=0,
maximize=false,
kw...
Expand All @@ -597,12 +587,12 @@ function xgboost(dm::DMatrix, a...;
# We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1
if isa(watchlist, Dict)
if early_stopping_rounds > 0 && length(watchlist) > 1
@warn "Early stopping rounds activated whilst watchlist has more than 1 element. Recommended to provide watchlist as an OrderedDict to ensure deterministic behaviour."
@error "You must supply an OrderedDict type for watchlist if early stopping rounds is enabled."
end
end

if isempty(watchlist) && early_stopping_rounds > 0
@warn "Early stopping is ignored as provided watchlist is empty."
@error "Watchlist must be supplied if early_stopping_rounds is enabled."
end

isempty(watchlist) || @info("XGBoost: starting training.")
Expand Down

0 comments on commit be3236b

Please sign in to comment.