Early stopping rounds #3

Merged · 8 commits · Oct 27, 2023
41 changes: 41 additions & 0 deletions docs/src/index.md
@@ -127,6 +127,7 @@ Unlike feature data, label data can be extracted after construction of the `DMatrix`
[`XGBoost.getlabel`](@ref).



## Booster
The [`Booster`](@ref) object holds model data. It is created with training data. Internally
this is always a `DMatrix`, but arguments will be converted automatically.
@@ -182,3 +183,43 @@ is equivalent to
bst = xgboost((X, y), num_round=10)
update!(bst, (X, y), num_round=10)
```

### Early Stopping
To help prevent overfitting to the training set, it is useful to evaluate against a validation set so that the boosting iterations continue to generalise rather than merely reducing training loss. Early stopping provides a convenient way to automatically stop the
boosting process when the generalisation capability of the model has not improved for `k` rounds.

If there is more than one element in the watchlist, by default the last element is used. This makes it important to use an ordered data structure (`OrderedDict`) rather than a standard unordered `Dict`, which does not guarantee deterministic iteration order. A warning is
emitted if early stopping is requested (`early_stopping_rounds > 0`) but the watchlist is a `Dict` with
more than one element.

Similarly, if there is more than one element in `eval_metric`, by default the last element is used.
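
The distinction is a property of the dictionary types themselves; a quick illustration independent of XGBoost:

```julia
using OrderedCollections

# A standard Dict makes no guarantee about iteration order, so "the last
# element" is not well defined across implementations or versions.
d = Dict("train" => 1, "eval" => 2)

# An OrderedDict preserves insertion order, so "eval" is reliably last here.
od = OrderedDict("train" => 1, "eval" => 2)
collect(keys(od))  # ["train", "eval"]
```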

For example:

```julia
using XGBoost
using LinearAlgebra
using OrderedCollections

𝒻(x) = 2norm(x)^2 - norm(x)

X = randn(100,3)
y = 𝒻.(eachrow(X))

dtrain = DMatrix((X, y))

X_valid = randn(50,3)
y_valid = 𝒻.(eachrow(X_valid))

dvalid = DMatrix((X_valid, y_valid))

bst = xgboost(dtrain, num_round = 100, eval_metric = "rmse",
              watchlist = OrderedDict(["train" => dtrain, "eval" => dvalid]),
              early_stopping_rounds = 5, max_depth = 6, η = 0.3)

# get the best iteration and use it for prediction
ŷ = predict(bst, X_valid, ntree_limit = bst.best_iteration)

using Statistics
println("RMSE from model prediction $(round((mean((ŷ - y_valid).^2).^0.5), digits = 8)).")

# the best score (based on eval_metric) is also stored in the booster
println("Best RMSE from model training $(round((bst.best_score), digits = 8)).")
```
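
For contrast, a sketch of the case that triggers the warning described above, reusing `dtrain` and `dvalid` from the example:

```julia
# With an unordered Dict of more than one element, which dataset counts as
# "last" is not deterministic, so a warning recommending OrderedDict is logged.
bst = xgboost(dtrain, num_round = 100, eval_metric = "rmse",
              watchlist = Dict("train" => dtrain, "eval" => dvalid),
              early_stopping_rounds = 5)
```
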
98 changes: 73 additions & 25 deletions src/booster.jl
@@ -1,4 +1,3 @@

"""
Booster

@@ -50,11 +49,17 @@ mutable struct Booster
# out what the hell is happening, it's never used for program logic
params::Dict{Symbol,Any}

function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict())
# store early stopping information
best_iteration::Union{Int64, Missing}
best_score::Union{Float64, Missing}

function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict(), best_iteration::Union{Int64, Missing}=missing,
best_score::Union{Float64, Missing}=missing)
finalizer(x -> xgbcall(XGBoosterFree, x.handle), new(h, fsn, params, best_iteration, best_score))
end
end
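
With these fields in place, the best round found by early stopping can be read back from a trained `Booster`. A minimal sketch (the fields stay `missing` unless early stopping actually triggered, so guard the access; `bst` is assumed to come from an `xgboost` call):

```julia
# Inspect the early-stopping results stored on a trained Booster.
if !ismissing(bst.best_iteration)
    @info "Early stopping result" bst.best_iteration bst.best_score
end
```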


"""
setparam!(b::Booster, name, val)

@@ -366,8 +371,14 @@ function updateone!(b::Booster, Xy::DMatrix;
update_feature_names::Bool=false,
)
xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle)
isempty(watchlist) || (msg = evaliter(b, watchlist, round_number))
@info msg
# obtain the logs if watchlist is present (for early stopping and/or info)
if isempty(watchlist)
msg = nothing
else
msg = evaliter(b, watchlist, round_number)
@info msg
end
_maybe_update_feature_names!(b, Xy, update_feature_names)
b, msg
end
@@ -383,8 +394,14 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr
g = convert(Vector{Cfloat}, g)
h = convert(Vector{Cfloat}, h)
xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g))
isempty(watchlist) || (msg = evaliter(b, watchlist, round_number))
@info msg
# obtain the logs if watchlist is present (for early stopping and/or info)
if isempty(watchlist)
msg = nothing
else
msg = evaliter(b, watchlist, round_number)
@info msg
end
_maybe_update_feature_names!(b, Xy, update_feature_names)
b, msg
end
@@ -426,7 +443,7 @@ for custom loss.
"""
function update!(b::Booster, data, a...;
num_round::Integer=1,
watchlist=Dict("train"=>Xy),
watchlist::Any = Dict("train" => data),
early_stopping_rounds::Integer=0,
maximize=false,
kw...,
@@ -440,7 +457,8 @@ function update!(b::Booster, data, a...;

for j ∈ 1:num_round
round_number = getnrounds(b) + 1
b, msg = updateone!(b, data, a...; round_number, kw...)

b, msg = updateone!(b, data, a...; round_number, watchlist, kw...)
if !isempty(watchlist) && early_stopping_rounds > 0
score, dataset, metric = extract_metric_value(msg)
if (maximize && score > best_score) || (!maximize && score < best_score)
@@ -450,7 +468,10 @@
@info(
"Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds."
)
return (b)
# record the best iteration and score on the booster
b.best_iteration = best_round
b.best_score = best_score
return b
end
end
end
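
The branch above is standard early-stopping bookkeeping; the rule in isolation, with illustrative names rather than the package API:

```julia
# Track the best score seen so far and stop once it has not improved
# for k consecutive rounds (minimising case, with its maximising mirror).
function early_stop_demo(scores::Vector{Float64}; k::Int=5, maximize::Bool=false)
    best_score = maximize ? -Inf : Inf
    best_round = 0
    for (round, score) in enumerate(scores)
        if (maximize && score > best_score) || (!maximize && score < best_score)
            best_score, best_round = score, round
        elseif round - best_round >= k
            break  # no improvement in k rounds
        end
    end
    return best_round, best_score
end

early_stop_demo([0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65])  # (3, 0.6)
```
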
@@ -489,14 +510,14 @@ println(value_with_params) # Output: (0.0951638480322251, "train", "rmsle")

function extract_metric_value(msg, dataset=nothing, metric=nothing)
if isnothing(dataset)
# Find the last mentioned dataset
datasets = Set([m.match for m in eachmatch(r"\w+(?=-)", msg)])
# Find the last mentioned dataset - whilst retaining order
datasets = unique([m.match for m in eachmatch(r"\w+(?=-)", msg)])
dataset = last(collect(datasets))
end

if isnothing(metric)
# Find the first mentioned metric
metrics = Set([m.match for m in eachmatch(r"(?<=-)\w+", msg)])
# Find the last mentioned metric - whilst retaining order
metrics = unique([m.match for m in eachmatch(r"(?<=-)\w+", msg)])
metric = last(collect(metrics))
end

@@ -513,8 +534,6 @@ function extract_metric_value(msg, dataset=nothing, metric=nothing)
end
end
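
To see what those regexes pull out of an evaluation message, a quick sketch on a hand-written log line (the `"<dataset>-<metric>:<value>"` format is assumed from the docstring example above):

```julia
# Hypothetical evaluation log line.
msg = "[6]\ttrain-rmsle:0.09516\ttrain-rmse:0.12455\teval-rmsle:0.10281\teval-rmse:0.13579"

datasets = unique([m.match for m in eachmatch(r"\w+(?=-)", msg)])  # ["train", "eval"]
metrics  = unique([m.match for m in eachmatch(r"(?<=-)\w+", msg)]) # ["rmsle", "rmse"]

# The defaults therefore select the last dataset and the last metric:
last(datasets), last(metrics)  # ("eval", "rmse")
```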



"""
xgboost(data; num_round=10, watchlist=Dict(), kw...)
xgboost(data, ℓ′, ℓ″; kw...)
@@ -524,10 +543,13 @@
followed by [`update!`](@ref) for `nrounds`.

`watchlist` is a dict the keys of which are strings giving the name of the data to watch
and the values of which are [`DMatrix`](@ref) objects containing the data.
and the values of which are [`DMatrix`](@ref) objects containing the data. It is critical to use an `OrderedDict`
when utilising `early_stopping_rounds`, to ensure XGBoost evaluates the intended dataset when deciding to stop early.

`early_stopping_rounds` if 0, the early stopping function is not triggered. If set to a positive integer,
training with a validation set will stop if the performance doesn't improve for k rounds.
`early_stopping_rounds` activates early stopping if set to a value greater than 0. The validation metric must improve at
least once every `k` rounds. If `watchlist` is not explicitly provided, the training dataset is used
to evaluate the stopping criterion; otherwise the last dataset in `watchlist` is used, together with the
last metric in `eval_metric` (if more than one is given). Note that early stopping is ignored if `watchlist` is empty.

`maximize` must be set whenever `early_stopping_rounds` is set. When it is `false`,
a smaller evaluation score is better; when set to `true`,
@@ -542,25 +564,51 @@ See [`updateone!`](@ref) for more details.

## Examples
```julia
# Example 1: Basic usage of XGBoost
(X, y) = (randn(100,3), randn(100))

b = xgboost((X, y), 10, max_depth=10, η=0.1)
b = xgboost((X, y), num_round=10, max_depth=10, η=0.1)

ŷ = predict(b, X)

# Example 2: Using early stopping (using a validation set) with a watchlist
dtrain = DMatrix((randn(100,3), randn(100)))
dvalid = DMatrix((randn(100,3), randn(100)))

using OrderedCollections
watchlist = OrderedDict(["train" => dtrain, "valid" => dvalid])

b = xgboost(dtrain, num_round=10, early_stopping_rounds = 2, watchlist = watchlist, max_depth=10, η=0.1)

# ntree_limit in predict sets the upper bound of iteration_range in the XGBoost API (1.4+)
ŷ = predict(b, dvalid, ntree_limit = b.best_iteration)
```
"""
function xgboost(dm::DMatrix, a...;
num_round::Integer=10,
watchlist=Dict("train"=>dm),
early_stopping_rounds::Integer=0,
maximize=false,
kw...
)
num_round::Integer=10,
watchlist::Any = Dict("train" => dm),
early_stopping_rounds::Integer=0,
maximize=false,
kw...
)

Xy = DMatrix(dm)
b = Booster(Xy; kw...)

# warn when early stopping is requested but the watchlist is an unordered Dict with more than one element
if isa(watchlist, Dict) && early_stopping_rounds > 0 && length(watchlist) > 1
@warn "Early stopping rounds activated whilst watchlist has more than 1 element. Recommended to provide watchlist as an OrderedDict to ensure deterministic behaviour."
end

if isempty(watchlist) && early_stopping_rounds > 0
@warn "Early stopping is ignored as provided watchlist is empty."
end

isempty(watchlist) || @info("XGBoost: starting training.")
update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize)
isempty(watchlist) || @info("Training rounds complete.")
b
end

xgboost(data, a...; kw...) = xgboost(DMatrix(data), a...; kw...)
97 changes: 90 additions & 7 deletions test/runtests.jl
@@ -3,6 +3,7 @@ using CUDA: has_cuda, cu
import Term
using Random, SparseArrays
using Test
using OrderedCollections

include("utils.jl")

@@ -135,6 +136,7 @@ end

dtrain = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"), format=:libsvm)
dtest = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"), format=:libsvm)
# test the early stopping rounds interface with a Dict data type in the watchlist
watchlist = Dict("eval"=>dtest, "train"=>dtrain)

bst = xgboost(dtrain,
@@ -154,15 +156,96 @@
early_stopping_rounds = 2
)

nrounds_bst = XGBoost.getnrounds(bst)
nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping)
# Check to see that running with early stopping results in less rounds
@test nrounds_bst_early_stopping < nrounds_bst
nrounds_bst = XGBoost.getnrounds(bst)
nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping)
# Check that running with early stopping results in at most as many rounds
@test nrounds_bst_early_stopping <= nrounds_bst

# Check number of rounds > early stopping rounds
@test nrounds_bst_early_stopping > 2
end
# Check number of rounds > early stopping rounds
@test nrounds_bst_early_stopping > 2

# test the early stopping rounds interface with an OrderedDict data type in the watchlist
watchlist_ordered = OrderedDict("train"=>dtrain, "eval"=>dtest)

bst_early_stopping = xgboost(dtrain,
num_round=30,
watchlist=watchlist_ordered,
η=1,
objective="binary:logistic",
eval_metric=["rmsle","rmse"],
early_stopping_rounds = 2
)

@test XGBoost.getnrounds(bst_early_stopping) > 2
@test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst

# compute the rmse on dtest using the best iteration
ŷ = predict(bst_early_stopping, dtest, ntree_limit = bst_early_stopping.best_iteration)

filename = "agaricus.txt.test"
lines = readlines(testfilepath(filename))
y = [parse(Float64,split(s)[1]) for s in lines]

function calc_rmse(y_true::Vector{T}, y_pred::Vector{T}) where T <: Float64
return sqrt(sum((y_true .- y_pred).^2)/length(y_true))
end

calc_metric = calc_rmse(Float64.(y), Float64.(ŷ))

# ensure the stored best score matches the recomputed metric (to numerical precision) at the best round
@test abs(bst_early_stopping.best_score - calc_metric) < 1e-9

# test the early stopping rounds interface with an OrderedDict watchlist and the num_parallel_tree parameter
# this also checks that iteration_range in the XGBoost API is utilised properly
watchlist_ordered = OrderedDict("train"=>dtrain, "eval"=>dtest)

bst_early_stopping = xgboost(dtrain,
num_round=30,
watchlist=watchlist_ordered,
η=1,
objective="binary:logistic",
eval_metric=["rmsle","rmse"],
early_stopping_rounds = 2,
num_parallel_tree = 10,
colsample_bylevel = 0.5
)

@test XGBoost.getnrounds(bst_early_stopping) > 2
@test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst

# compute the rmse on dtest using the best iteration
ŷ = predict(bst_early_stopping, dtest, ntree_limit = bst_early_stopping.best_iteration)
calc_metric = calc_rmse(Float64.(y), Float64.(ŷ))

# ensure the stored best score matches the recomputed metric (to numerical precision) at the best round
@test abs(bst_early_stopping.best_score - calc_metric) < 1e-9

# test the interface with no watchlist provided (defaults to the training dataset)
bst_early_stopping = xgboost(dtrain,
num_round=30,
η=1,
objective="binary:logistic",
eval_metric=["rmsle","rmse"],
early_stopping_rounds = 2
)

@test XGBoost.getnrounds(bst_early_stopping) > 2
@test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst

# test the interface with an empty watchlist (no output)
# this should trigger no early stopping rounds
bst_empty_watchlist = xgboost(dtrain,
num_round=30,
η=1,
watchlist = Dict(),
objective="binary:logistic",
eval_metric=["rmsle","rmse"],
early_stopping_rounds = 2 # this should be ignored
)

@test XGBoost.getnrounds(bst_empty_watchlist) == nrounds_bst

end


@testset "Blobs training" begin