Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality for early stopping rounds. #193

Merged
merged 19 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
14fe274
add functionality for early stopping
david-sun-1 Oct 8, 2023
f70ee6d
remove version word
david-sun-1 Oct 8, 2023
7ed00c4
Merge pull request #1 from jahan-ai/early_stopping_rounds
david-sun-1 Oct 8, 2023
5c9baba
evaluation msg into a parsing function and add back evaluation to upd…
david-sun-1 Oct 11, 2023
1e9040b
Merge pull request #2 from jahan-ai/early_stopping_rounds
david-sun-1 Oct 11, 2023
3f7fef4
Updated the call to updateone! to pass in the watchlist so it can be …
wilan-wong-1 Oct 25, 2023
154ca63
Added comments, additional examples, fixed issues with watchlist orde…
wilan-wong-1 Oct 26, 2023
169d563
Added functionality to extract the best iteration round with examples…
wilan-wong-1 Oct 26, 2023
b564be5
Cleaned up some lingering test cases.
wilan-wong-1 Oct 26, 2023
7af7f75
Updated doc to include early stopping example.
wilan-wong-1 Oct 27, 2023
ec8d066
Added additional info on data types for watchlist
wilan-wong-1 Oct 27, 2023
0b8be97
Annotated OrderedDict to be more obvious.
wilan-wong-1 Oct 27, 2023
13e4b84
Included using statement for OrderedCollection
wilan-wong-1 Oct 27, 2023
3bee176
Merge pull request #3 from jahan-ai/early_stopping_rounds
wilan-wong-1 Oct 27, 2023
be3236b
Moved log message parsing to update! instead of updateone
wilan-wong-1 Nov 9, 2023
a24258a
Updated documentation and tests.
wilan-wong-1 Nov 9, 2023
61abfeb
Altered the XGBoost method definition to reflect exception states for…
wilan-wong-1 Nov 9, 2023
b23bffd
Created exception if extract_metric_value could not find a match when…
wilan-wong-1 Nov 9, 2023
375fbb0
Merge pull request #4 from jahan-ai/early_stopping_rounds
wilan-wong-1 Nov 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 101 additions & 7 deletions src/booster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,10 @@ function updateone!(b::Booster, Xy::DMatrix;
update_feature_names::Bool=false,
)
xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle)
isempty(watchlist) || logeval(b, watchlist, round_number)
isempty(watchlist) || (msg = evaliter(b, watchlist, round_number))
@info msg
_maybe_update_feature_names!(b, Xy, update_feature_names)
b
b, msg
ExpandingMan marked this conversation as resolved.
Show resolved Hide resolved
end

function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::AbstractVector{<:Real};
Expand All @@ -382,9 +383,10 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr
g = convert(Vector{Cfloat}, g)
h = convert(Vector{Cfloat}, h)
xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g))
isempty(watchlist) || logeval(b, watchlist, round_number)
isempty(watchlist) || (msg = evaliter(b, watchlist, round_number))
@info msg
_maybe_update_feature_names!(b, Xy, update_feature_names)
b
b, msg
ExpandingMan marked this conversation as resolved.
Show resolved Hide resolved
end

"""
Expand Down Expand Up @@ -422,14 +424,97 @@ Run `num_round` rounds of gradient boosting on [`Booster`](@ref) `b`.
The first and second derivatives of the loss function (`ℓ′` and `ℓ″` respectively) can be provided
for custom loss.
"""
function update!(b::Booster, data, a...;
                 num_round::Integer=1,
                 watchlist=Dict("train"=>data),
                 early_stopping_rounds::Integer=0,
                 maximize=false,
                 kw...,
                )
    # Run up to `num_round` boosting rounds on `b`, optionally stopping early.
    #
    # Keywords:
    # - `num_round`: maximum number of rounds to run.
    # - `watchlist`: named data sets evaluated after every round; forwarded to `updateone!`.
    #   The default watches the training data itself (the previous default referenced an
    #   undefined name and threw as soon as `watchlist` was omitted).
    # - `early_stopping_rounds`: if positive, stop once the tracked score has failed to
    #   improve for this many rounds. `0` disables early stopping.
    # - `maximize`: only used when early stopping is active; `true` means larger is better.
    # - `kw...`: forwarded unchanged to `updateone!`.
    #
    # Returns the booster. Throws an `ErrorException` if early stopping is active and no
    # score can be read from the evaluation message.

    # Early stopping needs an evaluation score, so it is only active when there is
    # something to evaluate and a positive patience was requested.
    early_stopping = !isempty(watchlist) && early_stopping_rounds > 0

    best_round = 0
    best_score = maximize ? -Inf : Inf
    if early_stopping
        @info("Will train until there has been no improvement in $early_stopping_rounds rounds.\n")
    end

    for j ∈ 1:num_round
        round_number = getnrounds(b) + 1
        # NOTE(review): assumes `updateone!` returns `(booster, evaluation message)`, as the
        # revised methods in this change do — confirm against the final `updateone!`.
        b, msg = updateone!(b, data, a...; round_number, watchlist, kw...)
        early_stopping || continue

        parsed = extract_metric_value(msg)
        # `extract_metric_value` signals "no score found" with `nothing`; fail with a clear
        # message rather than an opaque destructuring error.
        isnothing(parsed) &&
            error("early stopping could not read a score from evaluation message: $msg")
        score, dataset, metric = parsed

        improved = maximize ? score > best_score : score < best_score
        if improved
            best_score = score
            best_round = j
        elseif j - best_round >= early_stopping_rounds
            @info(
                "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds."
            )
            return b
        end
    end
    return b
end



"""
extract_metric_value(msg, dataset=nothing, metric=nothing)

Extracts a numeric value from a message based on the specified dataset and metric.
If dataset or metric is not provided, the function will automatically find the last
mentioned dataset or metric in the message.

# Arguments
- `msg::AbstractString`: The message containing the numeric values.
- `dataset::Union{AbstractString, Nothing}`: The dataset to extract values for (default: `nothing`).
- `metric::Union{AbstractString, Nothing}`: The metric to extract values for (default: `nothing`).

# Returns
- Returns the tuple `(value, dataset, metric)`, where `value` is the parsed `Float64`, if a match
  is found; otherwise returns `nothing`.

# Examples
```julia
msg = "train-rmsle:0.09516384803222511 train-rmse:0.12458323318968342 eval-rmsle:0.09311178520817574 eval-rmse:0.12088154560829874"

# Without specifying dataset and metric (the last mentioned dataset and metric are used)
value_without_params = extract_metric_value(msg)
println(value_without_params) # Output: (0.12088154560829874, "eval", "rmse")

# With specifying dataset and metric
value_with_params = extract_metric_value(msg, "train", "rmsle")
println(value_with_params) # Output: (0.09516384803222511, "train", "rmsle")
```
"""

function extract_metric_value(msg, dataset=nothing, metric=nothing)
    # Read one `dataset-metric:value` entry out of an evaluation message.
    #
    # Returns `(value, dataset, metric)` with `value::Float64`, or `nothing` (after a
    # warning) when no matching entry with a readable number exists.
    #
    # When `dataset` and/or `metric` is `nothing`, the LAST entry in `msg` that agrees with
    # whichever of the two was supplied is used. Entries are scanned in message order, so
    # the choice is deterministic.
    if isnothing(dataset) || isnothing(metric)
        found = nothing
        # An entry is `<dataset>-<metric>:`. The dataset is a run of word characters that
        # does not start mid-token; the metric is everything up to the colon, so names
        # such as `poisson-nloglik` or `error@0.5` stay intact. Matching whole entries
        # also keeps the scan from mistaking the `5e` of a value like `1.5e-05` for a
        # dataset name.
        for m in eachmatch(r"(?<![\w-])(\w+)-([^\s:]+):", msg)
            dataset_ok = isnothing(dataset) || m.captures[1] == dataset
            metric_ok = isnothing(metric) || m.captures[2] == metric
            if dataset_ok && metric_ok
                found = m
            end
        end
        if isnothing(found)
            @warn "No dataset-metric entry found in message: $msg"
            return nothing
        end
        isnothing(dataset) && (dataset = found.captures[1])
        isnothing(metric) && (metric = found.captures[2])
    end

    # Accept signed decimals, scientific notation and nan/inf, not only digits and dots.
    number = raw"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|[-+]?(?i:nan|inf(?:inity)?))"
    # `\Q...\E` makes the names match literally even if they contain regex metacharacters.
    pattern = Regex(string(raw"(?<!\w)\Q", dataset, "-", metric, raw"\E:", number))

    match_result = match(pattern, msg)
    parsed_value = isnothing(match_result) ? nothing :
                   tryparse(Float64, match_result.captures[1])

    if isnothing(parsed_value)
        @warn "No match found for pattern: $dataset-$metric in message: $msg"
        return nothing
    end
    return parsed_value, dataset, metric
end



"""
xgboost(data; num_round=10, watchlist=Dict(), kw...)
xgboost(data, ℓ′, ℓ″; kw...)
Expand All @@ -441,6 +526,13 @@ followed by [`update!`](@ref) for `nrounds`.
`watchlist` is a dict the keys of which are strings giving the name of the data to watch
and the values of which are [`DMatrix`](@ref) objects containing the data.

`early_stopping_rounds`: if 0, early stopping is disabled. If set to a positive integer `k`,
training with a validation set stops once the evaluation score has not improved for `k` consecutive rounds.

`maximize` is only used if `early_stopping_rounds > 0`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to imply that the user must input the value whenever they input early_stopping_rounds. Maybe better to say something like "only used if early_stopping_rounds > 0".

When it is false, it means the smaller the evaluation score the better. When set to true,
the larger the evaluation score the better.

All other keyword arguments are passed to [`Booster`](@ref). With few exceptions these are model
training hyper-parameters, see [here](https://xgboost.readthedocs.io/en/stable/parameter.html) for
a comprehensive list.
Expand All @@ -460,12 +552,14 @@ ŷ = predict(b, X)
function xgboost(dm::DMatrix, a...;
                 num_round::Integer=10,
                 watchlist=Dict("train"=>dm),
                 early_stopping_rounds::Integer=0,
                 maximize=false,
                 kw...
                )
    # Build a booster for `dm` from the remaining keyword arguments, then train it.
    Xy = DMatrix(dm)
    b = Booster(Xy; kw...)
    isempty(watchlist) || @info("XGBoost: starting training.")
    # Train exactly once. The early-stopping settings are forwarded so `update!` can end
    # the run before `num_round` rounds have been performed.
    update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize)
    isempty(watchlist) || @info("Training rounds complete.")
    return b
end
Expand Down
35 changes: 35 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,41 @@ end
end
end


@testset "Early Stopping rounds" begin
    # Load the agaricus train/test files and watch both during training.
    train_data = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"), format=:libsvm)
    test_data = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"), format=:libsvm)
    monitored = Dict("eval"=>test_data, "train"=>train_data)

    # Settings common to both runs; only the early-stopping option differs.
    shared = (
        num_round=30,
        watchlist=monitored,
        η=1,
        objective="binary:logistic",
        eval_metric=["rmsle","rmse"],
    )

    baseline = xgboost(train_data; shared...)
    stopped = xgboost(train_data; shared..., early_stopping_rounds=2)

    rounds_full = XGBoost.getnrounds(baseline)
    rounds_stopped = XGBoost.getnrounds(stopped)

    # Early stopping must end training before the full round budget is used ...
    @test rounds_stopped < rounds_full

    # ... but only after more rounds than the patience window.
    @test rounds_stopped > 2
end



@testset "Blobs training" begin
(X, y) = load_classification()

Expand Down
Loading