
Add functionality for early stopping rounds. #193

Merged
merged 19 commits into from
Nov 17, 2023
Commits
14fe274
add functionality for early stopping
david-sun-1 Oct 8, 2023
f70ee6d
remove version word
david-sun-1 Oct 8, 2023
7ed00c4
Merge pull request #1 from jahan-ai/early_stopping_rounds
david-sun-1 Oct 8, 2023
5c9baba
evaluation msg into a parsing function and add back evaluation to upd…
david-sun-1 Oct 11, 2023
1e9040b
Merge pull request #2 from jahan-ai/early_stopping_rounds
david-sun-1 Oct 11, 2023
3f7fef4
Updated the call to updateone! to pass in the watchlist so it can be …
wilan-wong-1 Oct 25, 2023
154ca63
Added comments, additional examples, fixed issues with watchlist orde…
wilan-wong-1 Oct 26, 2023
169d563
Added functionality to extract the best iteration round with examples…
wilan-wong-1 Oct 26, 2023
b564be5
Cleaned up some lingering test cases.
wilan-wong-1 Oct 26, 2023
7af7f75
Updated doc to include early stopping example.
wilan-wong-1 Oct 27, 2023
ec8d066
Added additional info on data types for watchlist
wilan-wong-1 Oct 27, 2023
0b8be97
Annotated OrderedDict to be more obvious.
wilan-wong-1 Oct 27, 2023
13e4b84
Included using statement for OrderedCollection
wilan-wong-1 Oct 27, 2023
3bee176
Merge pull request #3 from jahan-ai/early_stopping_rounds
wilan-wong-1 Oct 27, 2023
be3236b
Moved log message parsing to update! instead of updateone
wilan-wong-1 Nov 9, 2023
a24258a
Updated documentation and tests.
wilan-wong-1 Nov 9, 2023
61abfeb
Altered the XGBoost method definition to reflect exception states for…
wilan-wong-1 Nov 9, 2023
b23bffd
Created exception if extract_metric_value could not find a match when…
wilan-wong-1 Nov 9, 2023
375fbb0
Merge pull request #4 from jahan-ai/early_stopping_rounds
wilan-wong-1 Nov 9, 2023
41 changes: 41 additions & 0 deletions docs/src/index.md
@@ -127,6 +127,7 @@ Unlike feature data, label data can be extracted after construction of the `DMatrix`
[`XGBoost.getlabel`](@ref).



## Booster
The [`Booster`](@ref) object holds model data. It is created with training data. Internally
this is always a `DMatrix`, but arguments will be automatically converted.
@@ -182,3 +183,43 @@ is equivalent to
bst = xgboost((X, y), num_round=10)
update!(bst, (X, y), num_round=10)
```

### Early Stopping
To help prevent overfitting to the training set, it is helpful to evaluate against a validation set so that the
XGBoost iterations continue to generalise beyond reductions in training loss. Early stopping provides a convenient way to automatically stop the
boosting process if the generalisation capability of the model does not improve for `k` rounds.

If there is more than one element in the watchlist, by default the last element is used. In this case you must use an
ordered data structure (`OrderedDict`) rather than a standard unordered `Dict`; otherwise an exception is thrown when
early stopping is enabled (`early_stopping_rounds > 0`) and the watchlist contains more than one element.

Similarly, if there is more than one element in `eval_metric`, by default the last element will be used.

For example:

```julia
using LinearAlgebra
using OrderedCollections

𝒻(x) = 2norm(x)^2 - norm(x)

X = randn(100,3)
y = 𝒻.(eachrow(X))

dtrain = DMatrix((X, y))

X_valid = randn(50,3)
y_valid = 𝒻.(eachrow(X_valid))

dvalid = DMatrix((X_valid, y_valid))

bst = xgboost(dtrain, num_round = 100, eval_metric = "rmse", watchlist = OrderedDict(["train" => dtrain, "eval" => dvalid]), early_stopping_rounds = 5, max_depth=6, η=0.3)

# get the best iteration and use it for prediction
ŷ = predict(bst, X_valid, ntree_limit = bst.best_iteration)

using Statistics
println("RMSE from model prediction $(round((mean((ŷ - y_valid).^2).^0.5), digits = 8)).")

# we can also retain / use the best score (based on eval_metric) which is stored in the booster
println("Best RMSE from model training $(round((bst.best_score), digits = 8)).")
```
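The last-element selection rule can be sketched more explicitly. This is an illustrative sketch, not one of the package's documented examples: it assumes `eval_metric` accepts a vector of metric names (passed through to libxgboost) and reuses `dtrain` and `dvalid` from above.

```julia
# Sketch (assumes eval_metric accepts a vector of metric names):
# with several metrics, the last one ("rmse" here) drives early stopping,
# and the last watchlist entry ("eval" here) is the dataset it is measured on.
bst = xgboost(dtrain, num_round = 100,
              eval_metric = ["mae", "rmse"],
              watchlist = OrderedDict(["train" => dtrain, "eval" => dvalid]),
              early_stopping_rounds = 5)
```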
160 changes: 147 additions & 13 deletions src/booster.jl
@@ -1,4 +1,3 @@

"""
Booster

@@ -50,11 +49,17 @@ mutable struct Booster
# out what the hell is happening, it's never used for program logic
params::Dict{Symbol,Any}

function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict())
# store early stopping information
best_iteration::Union{Int64, Missing}
best_score::Union{Float64, Missing}

function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict(), best_iteration::Union{Int64, Missing}=missing,
best_score::Union{Float64, Missing}=missing)
finalizer(x -> xgbcall(XGBoosterFree, x.handle), new(h, fsn, params))
end
end


"""
setparam!(b::Booster, name, val)

@@ -366,7 +371,6 @@ function updateone!(b::Booster, Xy::DMatrix;
update_feature_names::Bool=false,
)
xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle)
isempty(watchlist) || logeval(b, watchlist, round_number)
_maybe_update_feature_names!(b, Xy, update_feature_names)
b
end
@@ -382,7 +386,6 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr
g = convert(Vector{Cfloat}, g)
h = convert(Vector{Cfloat}, h)
xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g))
isempty(watchlist) || logeval(b, watchlist, round_number)
_maybe_update_feature_names!(b, Xy, update_feature_names)
b
end
@@ -422,14 +425,105 @@ Run `num_round` rounds of gradient boosting on [`Booster`](@ref) `b`.
The first and second derivatives of the loss function (`ℓ′` and `ℓ″` respectively) can be provided
for custom loss.
"""
function update!(b::Booster, data, a...; num_round::Integer=1, kw...)
function update!(b::Booster, data, a...;
num_round::Integer=1,
watchlist::Any = Dict("train" => data),
early_stopping_rounds::Integer=0,
maximize=false,
kw...,
)

if !isempty(watchlist) && early_stopping_rounds > 0
@info("Will train until there has been no improvement in $early_stopping_rounds rounds.\n")
best_round = 0
best_score = maximize ? -Inf : Inf
end

for j ∈ 1:num_round
round_number = getnrounds(b) + 1
updateone!(b, data, a...; round_number, kw...)

updateone!(b, data, a...; round_number, watchlist, kw...)

# Evaluate if watchlist is not empty
if !isempty(watchlist)
msg = evaliter(b, watchlist, round_number)
@info msg
if early_stopping_rounds > 0
score, dataset, metric = extract_metric_value(msg)
if (maximize && score > best_score) || (!maximize && score < best_score)
best_score = score
best_round = j
elseif j - best_round >= early_stopping_rounds
@info(
"Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds."
)
# add additional fields to record the best iteration
b.best_iteration = best_round
b.best_score = best_score
return b
end
end
end
end
b
end



"""
extract_metric_value(msg, dataset=nothing, metric=nothing)

Extracts a numeric value from a message based on the specified dataset and metric.
If dataset or metric is not provided, the function will automatically find the last
mentioned dataset or metric in the message.

# Arguments
- `msg::AbstractString`: The message containing the numeric values.
- `dataset::Union{AbstractString, Nothing}`: The dataset to extract values for (default: `nothing`).
- `metric::Union{AbstractString, Nothing}`: The metric to extract values for (default: `nothing`).

# Returns
- Returns the parsed Float64 value if a match is found, otherwise returns `nothing`.

# Examples
```julia
msg = "train-rmsle:0.09516384803222511 train-rmse:0.12458323318968342 eval-rmsle:0.09311178520817574 eval-rmse:0.12088154560829874"

# Without specifying dataset and metric
value_without_params = extract_metric_value(msg)
println(value_without_params) # Output: (0.12088154560829874, "eval", "rmse")

# With specifying dataset and metric
value_with_params = extract_metric_value(msg, "train", "rmsle")
println(value_with_params) # Output: (0.09516384803222511, "train", "rmsle")
```
"""

function extract_metric_value(msg, dataset=nothing, metric=nothing)
if isnothing(dataset)
# Find the last mentioned dataset - whilst retaining order
datasets = unique([m.match for m in eachmatch(r"\w+(?=-)", msg)])
dataset = last(collect(datasets))
end

if isnothing(metric)
# Find the last mentioned metric - whilst retaining order
metrics = unique([m.match for m in eachmatch(r"(?<=-)\w+", msg)])
metric = last(collect(metrics))
end

pattern = Regex("$dataset-$metric:([\\d.]+)")

match_result = match(pattern, msg)

if match_result !== nothing
parsed_value = parse(Float64, match_result.captures[1])
return parsed_value, dataset, metric
end

# there was no match result - should error out
error("No match found for pattern: $dataset-$metric in message: $msg")
end
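As a quick illustration of the parsing behaviour above (a sketch; the message is assumed to follow libxgboost's `name-metric:value` log format):

```julia
msg = "train-rmse:0.25 eval-rmse:0.31"

extract_metric_value(msg)                  # last dataset/metric win: (0.31, "eval", "rmse")
extract_metric_value(msg, "train", "rmse") # explicit selection: (0.25, "train", "rmse")
extract_metric_value(msg, "test", "auc")   # throws: no match for "test-auc"
```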

"""
xgboost(data; num_round=10, watchlist=Dict(), kw...)
xgboost(data, ℓ′, ℓ″; kw...)
@@ -439,7 +533,19 @@ This is essentially an alias for constructing a [`Booster`](@ref) with `data` an
followed by [`update!`](@ref) for `nrounds`.

`watchlist` is a dict the keys of which are strings giving the name of the data to watch
and the values of which are [`DMatrix`](@ref) objects containing the data.
and the values of which are [`DMatrix`](@ref) objects containing the data. When `early_stopping_rounds` is
enabled and `watchlist` contains more than one element, an `OrderedDict` must be used so that XGBoost
evaluates the stopping criterion against the intended (last) dataset.

`early_stopping_rounds` activates early stopping if set to a value greater than 0. The validation metric must
improve at least once in every `early_stopping_rounds` rounds. If `watchlist` is not explicitly provided, the
training dataset is used to evaluate the stopping criterion; otherwise, the last data element in `watchlist` and the
last metric in `eval_metric` (if there is more than one) are used. Note that `watchlist` cannot be empty if
`early_stopping_rounds` is enabled.

`maximize` is only used if `early_stopping_rounds > 0`. When `false`, a smaller evaluation score is better;
when `true`, a larger evaluation score is better.

All other keyword arguments are passed to [`Booster`](@ref). With few exceptions these are model
training hyper-parameters, see [here](https://xgboost.readthedocs.io/en/stable/parameter.html) for
@@ -450,23 +556,51 @@ See [`updateone!`](@ref) for more details.

## Examples
```julia
# Example 1: Basic usage of XGBoost
(X, y) = (randn(100,3), randn(100))

b = xgboost((X, y), 10, max_depth=10, η=0.1)
b = xgboost((X, y), num_round=10, max_depth=10, η=0.1)

ŷ = predict(b, X)

# Example 2: Early stopping with a validation set in the watchlist
dtrain = DMatrix((randn(100,3), randn(100)))
dvalid = DMatrix((randn(100,3), randn(100)))

watchlist = OrderedDict(["train" => dtrain, "valid" => dvalid])

b = xgboost(dtrain, num_round=10, early_stopping_rounds = 2, watchlist = watchlist, max_depth=10, η=0.1)

# note that ntree_limit in the predict function helps assign the upper bound for iteration_range in the XGBoost API 1.4+
ŷ = predict(b, dvalid, ntree_limit = b.best_iteration)
```
"""
function xgboost(dm::DMatrix, a...;
num_round::Integer=10,
watchlist=Dict("train"=>dm),
kw...
)
num_round::Integer=10,
watchlist::AbstractDict = Dict("train" => dm),
early_stopping_rounds::Integer=0,
maximize=false,
kw...
)

Xy = DMatrix(dm)
b = Booster(Xy; kw...)

# Error if early stopping is enabled and the watchlist is an unordered Dict with more than one element
if isa(watchlist, Dict) && early_stopping_rounds > 0 && length(watchlist) > 1
error("You must supply an OrderedDict type for watchlist if early stopping rounds is enabled and there is more than one element in watchlist.")
end

if isempty(watchlist) && early_stopping_rounds > 0
error("Watchlist must be supplied if early_stopping_rounds is enabled.")
end

isempty(watchlist) || @info("XGBoost: starting training.")
update!(b, Xy, a...; num_round, watchlist)
update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize)
isempty(watchlist) || @info("Training rounds complete.")
b
end

xgboost(data, a...; kw...) = xgboost(DMatrix(data), a...; kw...)