Skip to content

Commit

Permalink
Add functionality for early stopping rounds. (#193)
Browse files Browse the repository at this point in the history
* add functionality for early stopping

* remove version word

* evaluation msg into a parsing function and add back evaluation to updateone

* Updated the call to updateone! to pass in the watchlist so it can be used by early stopping round logic.

* Added comments, additional examples, fixed issues with watchlist ordering as a Dict.

* Added functionality to extract the best iteration round with examples. Included additional test case coverage.

* Cleaned up some lingering test cases.

* Updated doc to include early stopping example.

* Added additional info on data types for watchlist

* Annotated OrderedDict to be more obvious.

* Included using statement for OrderedCollection

* Moved log message parsing to update! instead of updateone

* Updated documentation and tests.

* Altered the XGBoost method definition to reflect exception states for early stopping rounds and watchlist.

* Created exception if extract_metric_value could not find a match when parsing XGBoost logs.

---------

Co-authored-by: Wilan Wong <[email protected]>
Co-authored-by: wilan-wong-1 <[email protected]>
  • Loading branch information
3 people authored Nov 17, 2023
1 parent c365c78 commit 4ead83f
Show file tree
Hide file tree
Showing 3 changed files with 305 additions and 13 deletions.
41 changes: 41 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ Unlike feature data, label data can be extracted after construction of the `DMat
[`XGBoost.getlabel`](@ref).



## Booster
The [`Booster`](@ref) object holds model data. They are created with training data. Internally
this is always a `DMatrix` but arguments will be automatically converted.
Expand Down Expand Up @@ -182,3 +183,43 @@ is equivalent to
bst = xgboost((X, y), num_round=10)
update!(bst, (X, y), num_round=10)
```

### Early Stopping
To help prevent overfitting to the training set, it is helpful to use a validation set to evaluate against to ensure that the XGBoost iterations continue to generalise outside training loss reduction. Early stopping provides a convenient way to automatically stop the
boosting process if it's observed that the generalisation capability of the model does not improve for `k` rounds.

If there is more than one element in watchlist, by default the last element will be used. In this case, you must use an ordered data structure (`OrderedDict`) compared to a standard unordered dictionary otherwise an exception will be generated. There will be
a warning if you want to execute early stopping mechanism (`early_stopping_rounds > 0`) but have provided a watchlist with type `Dict` with
more than 1 element.

Similarly, if there is more than one element in eval_metric, by default the last element will be used.

For example:

```julia
using LinearAlgebra
using OrderedCollections

𝒻(x) = 2norm(x)^2 - norm(x)

X = randn(100,3)
y = 𝒻.(eachrow(X))

dtrain = DMatrix((X, y))

X_valid = randn(50,3)
y_valid = 𝒻.(eachrow(X_valid))

dvalid = DMatrix((X_valid, y_valid))

bst = xgboost(dtrain, num_round = 100, eval_metric = "rmse", watchlist = OrderedDict(["train" => dtrain, "eval" => dvalid]), early_stopping_rounds = 5, max_depth=6, η=0.3)

# get the best iteration and use it for prediction
= predict(bst, X_valid, ntree_limit = bst.best_iteration)

using Statistics
println("RMSE from model prediction $(round((mean((ŷ - y_valid).^2).^0.5), digits = 8)).")

# we can also retain / use the best score (based on eval_metric) which is stored in the booster
println("Best RMSE from model training $(round((bst.best_score), digits = 8)).")
```
160 changes: 147 additions & 13 deletions src/booster.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
Booster
Expand Down Expand Up @@ -50,11 +49,17 @@ mutable struct Booster
# out what the hell is happening, it's never used for program logic
params::Dict{Symbol,Any}

function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict())
# store early stopping information
best_iteration::Union{Int64, Missing}
best_score::Union{Float64, Missing}

function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict(), best_iteration::Union{Int64, Missing}=missing,
best_score::Union{Float64, Missing}=missing)
finalizer(x -> xgbcall(XGBoosterFree, x.handle), new(h, fsn, params))
end
end


"""
setparam!(b::Booster, name, val)
Expand Down Expand Up @@ -366,7 +371,6 @@ function updateone!(b::Booster, Xy::DMatrix;
update_feature_names::Bool=false,
)
xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle)
isempty(watchlist) || logeval(b, watchlist, round_number)
_maybe_update_feature_names!(b, Xy, update_feature_names)
b
end
Expand All @@ -382,7 +386,6 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr
g = convert(Vector{Cfloat}, g)
h = convert(Vector{Cfloat}, h)
xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g))
isempty(watchlist) || logeval(b, watchlist, round_number)
_maybe_update_feature_names!(b, Xy, update_feature_names)
b
end
Expand Down Expand Up @@ -422,14 +425,105 @@ Run `num_round` rounds of gradient boosting on [`Booster`](@ref) `b`.
The first and second derivatives of the loss function (`ℓ′` and `ℓ″` respectively) can be provided
for custom loss.
"""
function update!(b::Booster, data, a...; num_round::Integer=1, kw...)
function update!(b::Booster, data, a...;
num_round::Integer=1,
watchlist::Any = Dict("train" => data),
early_stopping_rounds::Integer=0,
maximize=false,
kw...,
)

if !isempty(watchlist) && early_stopping_rounds > 0
@info("Will train until there has been no improvement in $early_stopping_rounds rounds.\n")
best_round = 0
best_score = maximize ? -Inf : Inf
end

for j 1:num_round
round_number = getnrounds(b) + 1
updateone!(b, data, a...; round_number, kw...)

updateone!(b, data, a...; round_number, watchlist, kw...)

# Evaluate if watchlist is not empty
if !isempty(watchlist)
msg = evaliter(b, watchlist, round_number)
@info msg
if early_stopping_rounds > 0
score, dataset, metric = extract_metric_value(msg)
if (maximize && score > best_score || (!maximize && score < best_score))
best_score = score
best_round = j
elseif j - best_round >= early_stopping_rounds
@info(
"Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds."
)
# add additional fields to record the best iteration
b.best_iteration = best_round
b.best_score = best_score
return b
end
end
end
end
b
end



"""
extract_metric_value(msg, dataset=nothing, metric=nothing)
Extracts a numeric value from a message based on the specified dataset and metric.
If dataset or metric is not provided, the function will automatically find the last
mentioned dataset or metric in the message.
# Arguments
- `msg::AbstractString`: The message containing the numeric values.
- `dataset::Union{AbstractString, Nothing}`: The dataset to extract values for (default: `nothing`).
- `metric::Union{AbstractString, Nothing}`: The metric to extract values for (default: `nothing`).
# Returns
- Returns the parsed Float64 value if a match is found, otherwise returns `nothing`.
# Examples
```julia
msg = "train-rmsle:0.09516384803222511 train-rmse:0.12458323318968342 eval-rmsle:0.09311178520817574 eval-rmse:0.12088154560829874"
# Without specifying dataset and metric
value_without_params = extract_metric_value(msg)
println(value_without_params) # Output: (0.09311178520817574, "eval", "rmsle")
# With specifying dataset and metric
value_with_params = extract_metric_value(msg, "train", "rmsle")
println(value_with_params) # Output: (0.0951638480322251, "train", "rmsle")
"""

function extract_metric_value(msg, dataset=nothing, metric=nothing)
if isnothing(dataset)
# Find the last mentioned dataset - whilst retaining order
datasets = unique([m.match for m in eachmatch(r"\w+(?=-)", msg)])
dataset = last(collect(datasets))
end

if isnothing(metric)
# Find the last mentioned metric - whilst retaining order
metrics = unique([m.match for m in eachmatch(r"(?<=-)\w+", msg)])
metric = last(collect(metrics))
end

pattern = Regex("$dataset-$metric:([\\d.]+)")

match_result = match(pattern, msg)

if match_result != nothing
parsed_value = parse(Float64, match_result.captures[1])
return parsed_value, dataset, metric
end

# there was no match result - should error out
error("No match found for pattern: $dataset-$metric in message: $msg")
end

"""
xgboost(data; num_round=10, watchlist=Dict(), kw...)
xgboost(data, ℓ′, ℓ″; kw...)
Expand All @@ -439,7 +533,19 @@ This is essentially an alias for constructing a [`Booster`](@ref) with `data` an
followed by [`update!`](@ref) for `nrounds`.
`watchlist` is a dict the keys of which are strings giving the name of the data to watch
and the values of which are [`DMatrix`](@ref) objects containing the data.
and the values of which are [`DMatrix`](@ref) objects containing the data. It is mandatory to use an OrderedDict
when utilising early_stopping_rounds and there is more than 1 element in watchlist to ensure XGBoost uses the
correct and intended dataset to perform early stop.
`early_stopping_rounds` activates early stopping if set to > 0. Validation metric needs to improve at
least once in every k rounds. If `watchlist` is not explicitly provided, it will use the training dataset
to evaluate the stopping criterion. Otherwise, it will use the last data element in `watchlist` and the
last metric in `eval_metric` (if more than one). Note that `watchlist` cannot be empty if
`early_stopping_rounds` is enabled.
`maximize` If early_stopping_rounds is set, then this parameter must be set as well.
When it is false, it means the smaller the evaluation score the better. When set to true,
the larger the evaluation score the better.
All other keyword arguments are passed to [`Booster`](@ref). With few exceptions these are model
training hyper-parameters, see [here](https://xgboost.readthedocs.io/en/stable/parameter.html) for
Expand All @@ -450,23 +556,51 @@ See [`updateone!`](@ref) for more details.
## Examples
```julia
# Example 1: Basic usage of XGBoost
(X, y) = (randn(100,3), randn(100))
b = xgboost((X, y), 10, max_depth=10, η=0.1)
b = xgboost((X, y), num_round=10, max_depth=10, η=0.1)
ŷ = predict(b, X)
# Example 2: Using early stopping (using a validation set) with a watchlist
dtrain = DMatrix((randn(100,3), randn(100)))
dvalid = DMatrix((randn(100,3), randn(100)))
watchlist = OrderedDict(["train" => dtrain, "valid" => dvalid])
b = xgboost(dtrain, num_round=10, early_stopping_rounds = 2, watchlist = watchlist, max_depth=10, η=0.1)
# note that ntree_limit in the predict function helps assign the upper bound for iteration_range in the XGBoost API 1.4+
ŷ = predict(b, dvalid, ntree_limit = b.best_iteration)
```
"""
function xgboost(dm::DMatrix, a...;
num_round::Integer=10,
watchlist=Dict("train"=>dm),
kw...
)
num_round::Integer=10,
watchlist::AbstractDict = Dict("train" => dm),
early_stopping_rounds::Integer=0,
maximize=false,
kw...
)

Xy = DMatrix(dm)
b = Booster(Xy; kw...)

# We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1
if isa(watchlist, Dict)
if early_stopping_rounds > 0 && length(watchlist) > 1
error("You must supply an OrderedDict type for watchlist if early stopping rounds is enabled and there is more than one element in watchlist.")
end
end

if isempty(watchlist) && early_stopping_rounds > 0
error("Watchlist must be supplied if early_stopping_rounds is enabled.")
end

isempty(watchlist) || @info("XGBoost: starting training.")
update!(b, Xy, a...; num_round, watchlist)
update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize)
isempty(watchlist) || @info("Training rounds complete.")
b
end

xgboost(data, a...; kw...) = xgboost(DMatrix(data), a...; kw...)
Loading

0 comments on commit 4ead83f

Please sign in to comment.