From 14fe274c09a9cf01f4384f575c6d354836a9caa4 Mon Sep 17 00:00:00 2001 From: David Sun Date: Mon, 9 Oct 2023 01:59:21 +1100 Subject: [PATCH 01/15] add functionality for early stopping --- src/booster.jl | 49 +++++++++++++++++++++++++++++++++++++++++++++--- test/runtests.jl | 35 ++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 3 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index bebb054..c38d69a 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -366,7 +366,6 @@ function updateone!(b::Booster, Xy::DMatrix; update_feature_names::Bool=false, ) xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle) - isempty(watchlist) || logeval(b, watchlist, round_number) _maybe_update_feature_names!(b, Xy, update_feature_names) b end @@ -422,10 +421,44 @@ Run `num_round` rounds of gradient boosting on [`Booster`](@ref) `b`. The first and second derivatives of the loss function (`ℓ′` and `ℓ″` respectively) can be provided for custom loss. """ -function update!(b::Booster, data, a...; num_round::Integer=1, kw...) +function update!(b::Booster, data, a...; + num_round::Integer=1, + watchlist=Dict("train"=>Xy), + early_stopping_rounds::Integer=0, + maximize=false, + kw..., + ) + + if !isempty(watchlist) && early_stopping_rounds > 0 + early_stopping_set = collect(Iterators.map(string, keys(watchlist)))[end] + println("Will train until there has been no improvement in $early_stopping_rounds rounds.\n") + best_round = 0 + best_score = maximize ? -Inf : Inf + end + for j ∈ 1:num_round round_number = getnrounds(b) + 1 updateone!(b, data, a...; round_number, kw...) + if !isempty(watchlist) + msg = evaliter(b, watchlist, j) + @info(msg) + if early_stopping_rounds > 0 + split_msg = split(msg, r"\s+|:") + metric_name_idx = findfirst(occursin.(early_stopping_set, split_msg)) + score = parse(Float64, split_msg[metric_name_idx + 1]) + + if (maximize && score > best_score || (!maximize && score < best_score)) + best_score = score + best_round = j + elseif j - best_round >= early_stopping_rounds + watchlist_metric = (split_msg[metric_name_idx]) + @info( + "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\t$(watchlist_metric) has not improved in $early_stopping_rounds rounds." + ) + return (b) + end + end + end end b end @@ -441,6 +474,13 @@ followed by [`update!`](@ref) for `nrounds`. `watchlist` is a dict the keys of which are strings giving the name of the data to watch and the values of which are [`DMatrix`](@ref) objects containing the data. +`early_stopping_rounds` if 0, the early stopping function is not triggered. If set to a positive integer, +training with a validation set will stop if the performance doesn't improve for k rounds. + +`maximize` If early_stopping_rounds is set, then this parameter must be set as well. +When it is false, it means the smaller the evaluation score the better. When set to true, +the larger the evaluation score the better. + All other keyword arguments are passed to [`Booster`](@ref). With few exceptions these are model training hyper-parameters, see [here](https://xgboost.readthedocs.io/en/stable/parameter.html) for a comprehensive list. @@ -460,12 +500,15 @@ ŷ = predict(b, X) function xgboost(dm::DMatrix, a...; num_round::Integer=10, watchlist=Dict("train"=>dm), + early_stopping_rounds::Integer=0, + maximize=false, kw... ) + println("Running DEV Version: 5") Xy = DMatrix(dm) b = Booster(Xy; kw...) isempty(watchlist) || @info("XGBoost: starting training.") - update!(b, Xy, a...; num_round, watchlist) + update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize) isempty(watchlist) || @info("Training rounds complete.") b end diff --git a/test/runtests.jl b/test/runtests.jl index cade6d5..39dd5cf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -130,6 +130,41 @@ end end end + +@testset "Early Stopping rounds" begin + + dtrain = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"), format=:libsvm) + dtest = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"), format=:libsvm) + watchlist = Dict("eval"=>dtest, "train"=>dtrain) + + bst = xgboost(dtrain, + num_round=30, + watchlist=watchlist, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"] + ) + + bst_early_stopping = xgboost(dtrain, + num_round=30, + watchlist=watchlist, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) + + nrounds_bst = XGBoost.getnrounds(bst) + nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping) + # Check to see that running with early stopping results in less rounds + @test nrounds_bst_early_stopping < nrounds_bst + + # Check number of rounds > early stopping rounds + @test nrounds_bst_early_stopping > 2 +end + + + @testset "Blobs training" begin (X, y) = load_classification() From f70ee6d99a699730d78d99e8fa976e9b570901e4 Mon Sep 17 00:00:00 2001 From: David Sun Date: Mon, 9 Oct 2023 02:00:37 +1100 Subject: [PATCH 02/15] remove version word --- src/booster.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/booster.jl b/src/booster.jl index c38d69a..2769b89 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -504,7 +504,6 @@ function xgboost(dm::DMatrix, a...; maximize=false, kw... ) - println("Running DEV Version: 5") Xy = DMatrix(dm) b = Booster(Xy; kw...) isempty(watchlist) || @info("XGBoost: starting training.") From 5c9babaaff1c5a095d07776a6c9790802070a728 Mon Sep 17 00:00:00 2001 From: David Sun Date: Wed, 11 Oct 2023 23:02:48 +1100 Subject: [PATCH 03/15] evaluation msg into a parsing function and add back evaluation to updateone --- src/booster.jl | 100 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 76 insertions(+), 24 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index 2769b89..96d1814 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -366,8 +366,10 @@ function updateone!(b::Booster, Xy::DMatrix; update_feature_names::Bool=false, ) xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle) + isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) + @info msg _maybe_update_feature_names!(b, Xy, update_feature_names) - b + b, msg end function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::AbstractVector{<:Real}; @@ -381,9 +383,10 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr g = convert(Vector{Cfloat}, g) h = convert(Vector{Cfloat}, h) xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g)) - isempty(watchlist) || logeval(b, watchlist, round_number) + isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) + @info msg _maybe_update_feature_names!(b, Xy, update_feature_names) - b + b, msg end """ @@ -430,39 +433,88 @@ function update!(b::Booster, data, a...; ) if !isempty(watchlist) && early_stopping_rounds > 0 - early_stopping_set = collect(Iterators.map(string, keys(watchlist)))[end] - println("Will train until there has been no improvement in $early_stopping_rounds rounds.\n") + @info("Will train until there has been no improvement in $early_stopping_rounds rounds.\n") best_round = 0 best_score = maximize ? -Inf : Inf end for j ∈ 1:num_round round_number = getnrounds(b) + 1 - updateone!(b, data, a...; round_number, kw...) - if !isempty(watchlist) - msg = evaliter(b, watchlist, j) - @info(msg) - if early_stopping_rounds > 0 - split_msg = split(msg, r"\s+|:") - metric_name_idx = findfirst(occursin.(early_stopping_set, split_msg)) - score = parse(Float64, split_msg[metric_name_idx + 1]) - - if (maximize && score > best_score || (!maximize && score < best_score)) - best_score = score - best_round = j - elseif j - best_round >= early_stopping_rounds - watchlist_metric = (split_msg[metric_name_idx]) - @info( - "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\t$(watchlist_metric) has not improved in $early_stopping_rounds rounds." - ) - return (b) - end + b, msg = updateone!(b, data, a...; round_number, kw...) + if !isempty(watchlist) && early_stopping_rounds > 0 + score, dataset, metric = extract_metric_value(msg) + if (maximize && score > best_score || (!maximize && score < best_score)) + best_score = score + best_round = j + elseif j - best_round >= early_stopping_rounds + @info( + "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds." + ) + return (b) end end end b end + + +""" + extract_metric_value(msg, dataset=nothing, metric=nothing) + +Extracts a numeric value from a message based on the specified dataset and metric. +If dataset or metric is not provided, the function will automatically find the last +mentioned dataset or metric in the message. + +# Arguments +- `msg::AbstractString`: The message containing the numeric values. +- `dataset::Union{AbstractString, Nothing}`: The dataset to extract values for (default: `nothing`). +- `metric::Union{AbstractString, Nothing}`: The metric to extract values for (default: `nothing`). + +# Returns +- Returns the parsed Float64 value if a match is found, otherwise returns `nothing`. + +# Examples +```julia +msg = "train-rmsle:0.09516384803222511 train-rmse:0.12458323318968342 eval-rmsle:0.09311178520817574 eval-rmse:0.12088154560829874" + +# Without specifying dataset and metric +value_without_params = extract_metric_value(msg) +println(value_without_params) # Output: (0.09311178520817574, "eval", "rmsle") + +# With specifying dataset and metric +value_with_params = extract_metric_value(msg, "train", "rmsle") +println(value_with_params) # Output: (0.0951638480322251, "train", "rmsle") +""" + +function extract_metric_value(msg, dataset=nothing, metric=nothing) + if isnothing(dataset) + # Find the last mentioned dataset + datasets = Set([m.match for m in eachmatch(r"\w+(?=-)", msg)]) + dataset = last(collect(datasets)) + end + + if isnothing(metric) + # Find the first mentioned metric + metrics = Set([m.match for m in eachmatch(r"(?<=-)\w+", msg)]) + metric = last(collect(metrics)) + end + + pattern = Regex("$dataset-$metric:([\\d.]+)") + + match_result = match(pattern, msg) + + if match_result != nothing + parsed_value = parse(Float64, match_result.captures[1]) + return parsed_value, dataset, metric + else + @warn "No match found for pattern: $dataset-$metric in message: $msg" + return nothing + end +end + + + """ xgboost(data; num_round=10, watchlist=Dict(), kw...) xgboost(data, ℓ′, ℓ″; kw...) From 3f7fef4da3938b0ee568c458bb90dcda2698e23d Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Wed, 25 Oct 2023 14:51:22 +1100 Subject: [PATCH 04/15] Updated the call to updateone! to pass in the watchlist so it can be used by early stopping round logic. --- src/booster.jl | 4 ++-- test/runtests.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index 96d1814..ca5fad4 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -440,7 +440,7 @@ function update!(b::Booster, data, a...; for j ∈ 1:num_round round_number = getnrounds(b) + 1 - b, msg = updateone!(b, data, a...; round_number, kw...) + b, msg = updateone!(b, data, a...; round_number, watchlist, kw...) if !isempty(watchlist) && early_stopping_rounds > 0 score, dataset, metric = extract_metric_value(msg) if (maximize && score > best_score || (!maximize && score < best_score)) @@ -495,7 +495,7 @@ function extract_metric_value(msg, dataset=nothing, metric=nothing) end if isnothing(metric) - # Find the first mentioned metric + # Find the last mentioned metric metrics = Set([m.match for m in eachmatch(r"(?<=-)\w+", msg)]) metric = last(collect(metrics)) end diff --git a/test/runtests.jl b/test/runtests.jl index 39dd5cf..7ffbd5a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -156,8 +156,8 @@ end nrounds_bst = XGBoost.getnrounds(bst) nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping) - # Check to see that running with early stopping results in less rounds - @test nrounds_bst_early_stopping < nrounds_bst + # Check to see that running with early stopping results in less than or equal rounds + @test nrounds_bst_early_stopping <= nrounds_bst # Check number of rounds > early stopping rounds @test nrounds_bst_early_stopping > 2 From 154ca63801316026d6ac88b44dd8b50cc22ad1f1 Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Thu, 26 Oct 2023 22:16:57 +1100 Subject: [PATCH 05/15] Added comments, additional examples, fixed issues with watchlist ordering as a Dict. --- src/booster.jl | 78 ++++++++++++++++++++++++++++++++++-------------- test/runtests.jl | 43 +++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 23 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index ca5fad4..d8beced 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -1,4 +1,3 @@ - """ Booster @@ -366,8 +365,14 @@ function updateone!(b::Booster, Xy::DMatrix; update_feature_names::Bool=false, ) xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle) - isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) - @info msg + # obtain the logs if watchlist is present (for early stopping and/or info) + if isempty(watchlist) + msg = nothing + else + msg = evaliter(b, watchlist, round_number) + @info msg + end + #isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) _maybe_update_feature_names!(b, Xy, update_feature_names) b, msg end @@ -383,8 +388,14 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr g = convert(Vector{Cfloat}, g) h = convert(Vector{Cfloat}, h) xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g)) - isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) - @info msg + # obtain the logs if watchlist is present (for early stopping and/or info) + if isempty(watchlist) + msg = nothing + else + msg = evaliter(b, watchlist, round_number) + @info msg + end + #isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) _maybe_update_feature_names!(b, Xy, update_feature_names) b, msg end @@ -426,7 +437,7 @@ for custom loss. """ function update!(b::Booster, data, a...; num_round::Integer=1, - watchlist=Dict("train"=>Xy), + watchlist::Any = Dict("train" => data), early_stopping_rounds::Integer=0, maximize=false, kw..., @@ -440,6 +451,7 @@ function update!(b::Booster, data, a...; for j ∈ 1:num_round round_number = getnrounds(b) + 1 + b, msg = updateone!(b, data, a...; round_number, watchlist, kw...) if !isempty(watchlist) && early_stopping_rounds > 0 score, dataset, metric = extract_metric_value(msg) @@ -489,14 +501,14 @@ println(value_with_params) # Output: (0.0951638480322251, "train", "rmsle") function extract_metric_value(msg, dataset=nothing, metric=nothing) if isnothing(dataset) - # Find the last mentioned dataset - datasets = Set([m.match for m in eachmatch(r"\w+(?=-)", msg)]) + # Find the last mentioned dataset - whilst retaining order + datasets = unique([m.match for m in eachmatch(r"\w+(?=-)", msg)]) dataset = last(collect(datasets)) end if isnothing(metric) - # Find the last mentioned metric - metrics = Set([m.match for m in eachmatch(r"(?<=-)\w+", msg)]) + # Find the last mentioned metric - whilst retaining order + metrics = unique([m.match for m in eachmatch(r"(?<=-)\w+", msg)]) metric = last(collect(metrics)) end @@ -513,8 +525,6 @@ function extract_metric_value(msg, dataset=nothing, metric=nothing) end end - - """ xgboost(data; num_round=10, watchlist=Dict(), kw...) xgboost(data, ℓ′, ℓ″; kw...) @@ -524,10 +534,13 @@ This is essentially an alias for constructing a [`Booster`](@ref) with `data` an followed by [`update!`](@ref) for `nrounds`. `watchlist` is a dict the keys of which are strings giving the name of the data to watch -and the values of which are [`DMatrix`](@ref) objects containing the data. +and the values of which are [`DMatrix`](@ref) objects containing the data. It is critical to use an OrderedDict +when utilising early_stopping_rounds to ensure XGBoost uses the correct and intended dataset to perform early stop. -`early_stopping_rounds` if 0, the early stopping function is not triggered. If set to a positive integer, -training with a validation set will stop if the performance doesn't improve for k rounds. +`early_stopping_rounds` activates early stopping if set to > 0. Validation metric needs to improve at +least once in every k rounds. If `watchlist` is not explicitly provided, it will use the training dataset +to evaluate the stopping criterion. Otherwise, it will use the last data element in `watchlist` and the +last metric in `eval_metric` (if more than one). `maximize` If early_stopping_rounds is set, then this parameter must be set as well. When it is false, it means the smaller the evaluation score the better. When set to true, @@ -542,25 +555,46 @@ See [`updateone!`](@ref) for more details. ## Examples ```julia +# Example 1: Basic usage of XGBoost (X, y) = (randn(100,3), randn(100)) -b = xgboost((X, y), 10, max_depth=10, η=0.1) +b = xgboost((X, y), num_round=10, max_depth=10, η=0.1) + +ŷ = predict(b, X) + +# Example 2: Using early stopping (using a validation set) with a watchlist +dtrain = DMatrix((randn(100,3), randn(100))) +dvalid = DMatrix((randn(100,3), randn(100))) + +watchlist = OrderedDict(["train" => dtrain, "valid" => dvalid]) + +b = xgboost(dtrain, num_round=10, early_stopping_rounds = 2, watchlist = watchlist, max_depth=10, η=0.1) ŷ = predict(b, X) ``` """ function xgboost(dm::DMatrix, a...; - num_round::Integer=10, - watchlist=Dict("train"=>dm), - early_stopping_rounds::Integer=0, - maximize=false, - kw... - ) + num_round::Integer=10, + watchlist::Any = Dict("train" => dm), + early_stopping_rounds::Integer=0, + maximize=false, + kw... + ) + Xy = DMatrix(dm) b = Booster(Xy; kw...) + + # We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1 + if isa(watchlist, Dict) + if early_stopping_rounds > 0 && length(watchlist) > 1 + @warn "Early stopping rounds activated whilst watchlist has more than 1 element. Recommended to provide watchlist as an OrderedDict to ensure deterministic behaviour." + end + end + isempty(watchlist) || @info("XGBoost: starting training.") update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize) isempty(watchlist) || @info("Training rounds complete.") b end + xgboost(data, a...; kw...) = xgboost(DMatrix(data), a...; kw...) diff --git a/test/runtests.jl b/test/runtests.jl index 7ffbd5a..81fd0c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ using Test include("utils.jl") + @testset "XGBoost" begin # note that non-Float32 matrices will get truncated and `==` may not hold @@ -135,6 +136,7 @@ end dtrain = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"), format=:libsvm) dtest = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"), format=:libsvm) + # test the early stopping rounds interface with a Dict data type in the watchlist watchlist = Dict("eval"=>dtest, "train"=>dtrain) bst = xgboost(dtrain, @@ -161,8 +163,47 @@ end # Check number of rounds > early stopping rounds @test nrounds_bst_early_stopping > 2 -end + # test the early stopping rounds interface with an OrderedDict data type in the watchlist + watchlist_ordered = OrderedDict("eval"=>dtest, "train"=>dtrain) + + bst_early_stopping = xgboost(dtrain, + num_round=30, + watchlist=watchlist_ordered, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) + + @test XGBoost.getnrounds(bst_early_stopping) > 2 + @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst + + # test the interface with no watchlist provided (defaults to the training dataset) + bst_early_stopping = xgboost(dtrain, + num_round=30, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) + + @test XGBoost.getnrounds(bst_early_stopping) > 2 + @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst + + # test the interface with an empty watchlist (no output) + # this should trigger no early stopping rounds + bst_empty_watchlist = xgboost(dtrain, + num_round=30, + η=1, + watchlist = Dict(), + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 # this should be ignored + ) + + @test XGBoost.getnrounds(bst_empty_watchlist) == nrounds_bst +end @testset "Blobs training" begin From 169d563b073c752eb2b6303b3c24fc0c84e3fece Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Fri, 27 Oct 2023 00:55:20 +1100 Subject: [PATCH 06/15] Added functionality to extract the best iteration round with examples. Included additional test case coverage. --- src/booster.jl | 22 +++++++++++++++---- test/runtests.jl | 56 ++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index d8beced..5ed07ec 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -49,11 +49,17 @@ mutable struct Booster # out what the hell is happening, it's never used for program logic params::Dict{Symbol,Any} - function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict()) + # store early stopping information + best_iteration::Union{Int64, Missing} + best_score::Union{Float64, Missing} + + function Booster(h::BoosterHandle, fsn::AbstractVector{<:AbstractString}=String[], params::AbstractDict=Dict(), best_iteration::Union{Int64, Missing}=missing, + best_score::Union{Float64, Missing}=missing) finalizer(x -> xgbcall(XGBoosterFree, x.handle), new(h, fsn, params)) end end + """ setparam!(b::Booster, name, val) @@ -462,7 +468,10 @@ function update!(b::Booster, data, a...; @info( "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds." ) - return (b) + # add additional fields to record the best iteration + b.best_iteration = best_round + b.best_score = best_score + return b end end end @@ -540,7 +549,7 @@ when utilising early_stopping_rounds to ensure XGBoost uses the correct and inte `early_stopping_rounds` activates early stopping if set to > 0. Validation metric needs to improve at least once in every k rounds. If `watchlist` is not explicitly provided, it will use the training dataset to evaluate the stopping criterion. Otherwise, it will use the last data element in `watchlist` and the -last metric in `eval_metric` (if more than one). +last metric in `eval_metric` (if more than one). Note that early stopping is ignored if `watchlist` is empty. `maximize` If early_stopping_rounds is set, then this parameter must be set as well. When it is false, it means the smaller the evaluation score the better. When set to true, @@ -570,7 +579,8 @@ watchlist = OrderedDict(["train" => dtrain, "valid" => dvalid]) b = xgboost(dtrain, num_round=10, early_stopping_rounds = 2, watchlist = watchlist, max_depth=10, η=0.1) -ŷ = predict(b, X) +# note that ntree_limit in the predict function helps assign the upper bound for iteration_range in the XGBoost API 1.4+ +ŷ = predict(b, dvalid, ntree_limit = b.best_iteration) ``` """ function xgboost(dm::DMatrix, a...; @@ -590,6 +600,10 @@ function xgboost(dm::DMatrix, a...; @warn "Early stopping rounds activated whilst watchlist has more than 1 element. Recommended to provide watchlist as an OrderedDict to ensure deterministic behaviour." end end + + if isempty(watchlist) && early_stopping_rounds > 0 + @warn "Early stopping is ignored as provided watchlist is empty." + end isempty(watchlist) || @info("XGBoost: starting training.") update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize) diff --git a/test/runtests.jl b/test/runtests.jl index 81fd0c1..4f081db 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,6 @@ using Test include("utils.jl") - @testset "XGBoost" begin # note that non-Float32 matrices will get truncated and `==` may not hold @@ -165,7 +164,7 @@ end @test nrounds_bst_early_stopping > 2 # test the early stopping rounds interface with an OrderedDict data type in the watchlist - watchlist_ordered = OrderedDict("eval"=>dtest, "train"=>dtrain) + watchlist_ordered = OrderedDict("train"=>dtrain, "eval"=>dtest) bst_early_stopping = xgboost(dtrain, num_round=30, @@ -179,6 +178,47 @@ end @test XGBoost.getnrounds(bst_early_stopping) > 2 @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst + # get the rmse difference for the dtest + ŷ = predict(bst_early_stopping, dtest, ntree_limit = bst_early_stopping.best_iteration) + + filename = "agaricus.txt.test" + lines = readlines(testfilepath(filename)) + y = [parse(Float64,split(s)[1]) for s in lines] + + function calc_rmse(y_true::Vector{T}, y_pred::Vector{T}) where T <: Float64 + return sqrt(sum((y_true .- y_pred).^2)/length(y_true)) + end + + calc_metric = calc_rmse(Float64.(y), Float64.(ŷ)) + + # ensure that the results are the same (as numerically possible) with the best round + @test abs(bst_early_stopping.best_score - calc_metric) < 1e-9 + + # test the early stopping rounds interface with an OrderedDict data type in the watchlist using num_parallel_tree parameter + # this will test the XGBoost API for iteration_range is being utilised properly + watchlist_ordered = OrderedDict("train"=>dtrain, "eval"=>dtest) + + bst_early_stopping = xgboost(dtrain, + num_round=30, + watchlist=watchlist_ordered, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2, + num_parallel_tree = 10, + colsample_bylevel = 0.5 + ) + + @test XGBoost.getnrounds(bst_early_stopping) > 2 + @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst + + # get the rmse difference for the dtest + ŷ = predict(bst_early_stopping, dtest, ntree_limit = bst_early_stopping.best_iteration) + calc_metric = calc_rmse(Float64.(y), Float64.(ŷ)) + + # ensure that the results are the same (as numerically possible) with the best round + @test abs(bst_early_stopping.best_score - calc_metric) < 1e-9 + # test the interface with no watchlist provided (defaults to the training dataset) bst_early_stopping = xgboost(dtrain, num_round=30, @@ -203,6 +243,18 @@ end ) @test XGBoost.getnrounds(bst_empty_watchlist) == nrounds_bst + + # test the functionality of utilising the best model iteration + bst_early_stopping = xgboost(dtrain, + num_round=30, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) + + + end From b564be5f9e15ed66b872b854db0fe659ae68ea76 Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Fri, 27 Oct 2023 00:59:09 +1100 Subject: [PATCH 07/15] Cleaned up some lingering test cases. --- test/runtests.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4f081db..d4337c3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -244,17 +244,6 @@ end @test XGBoost.getnrounds(bst_empty_watchlist) == nrounds_bst - # test the functionality of utilising the best model iteration - bst_early_stopping = xgboost(dtrain, - num_round=30, - η=1, - objective="binary:logistic", - eval_metric=["rmsle","rmse"], - early_stopping_rounds = 2 - ) - - - end From 7af7f750e1cf53ccf46dfeaa73d214c328bf3902 Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Fri, 27 Oct 2023 12:09:47 +1100 Subject: [PATCH 08/15] Updated doc to include early stopping example. --- docs/src/index.md | 35 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 20 ++++++++++---------- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 3e10ce8..d743874 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -127,6 +127,7 @@ Unlike feature data, label data can be extracted after construction of the `DMat [`XGBoost.getlabel`](@ref). + ## Booster The [`Booster`](@ref) object holds model data. They are created with training data. Internally this is always a `DMatrix` but arguments will be automatically converted. @@ -182,3 +183,37 @@ is equivalent to bst = xgboost((X, y), num_round=10) update!(bst, (X, y), num_round=10) ``` + +### Early Stopping +To help prevent overfitting to the training set, it is helpful to use a validation set to evaluate against to ensure that the XGBoost iterations continue to generalise outside training loss reduction. Early stopping provides a convenient way to automatically stop the +boosting process if it's observed that the generalisation capability of the model does not improve for `k` rounds. + +For example: + +```julia +using LinearAlgebra +using OrderedCollections + +𝒻(x) = 2norm(x)^2 - norm(x) + +X = randn(100,3) +y = 𝒻.(eachrow(X)) + +dtrain = DMatrix((X, y)) + +X_valid = randn(50,3) +y_valid = 𝒻.(eachrow(X_valid)) + +dvalid = DMatrix((X_valid, y_valid)) + +bst = xgboost(dtrain, num_round = 100, eval_metric = "rmse", watchlist = OrderedDict(["train" => dtrain, "eval" => dvalid]), early_stopping_rounds = 5, max_depth=6, η=0.3) + +# get the best iteration and use it for prediction +ŷ = predict(bst, X_valid, ntree_limit = bst.best_iteration) + +using Statistics +println("RMSE from model prediction $(round((mean((ŷ - y_valid).^2).^0.5), digits = 8)).") + +# we can also retain / use the best score (based on eval_metric) which is stored in the booster +println("Best RMSE from model training $(round((bst.best_score), digits = 8)).") +``` \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index d4337c3..0bd1c8c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -155,13 +155,13 @@ end early_stopping_rounds = 2 ) - nrounds_bst = XGBoost.getnrounds(bst) - nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping) - # Check to see that running with early stopping results in less than or equal rounds - @test nrounds_bst_early_stopping <= nrounds_bst + nrounds_bst = XGBoost.getnrounds(bst) + nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping) + # Check to see that running with early stopping results in less than or equal rounds + @test nrounds_bst_early_stopping <= nrounds_bst - # Check number of rounds > early stopping rounds - @test nrounds_bst_early_stopping > 2 + # Check number of rounds > early stopping rounds + @test nrounds_bst_early_stopping > 2 # test the early stopping rounds interface with an OrderedDict data type in the watchlist watchlist_ordered = OrderedDict("train"=>dtrain, "eval"=>dtest) @@ -228,8 +228,8 @@ end early_stopping_rounds = 2 ) - @test XGBoost.getnrounds(bst_early_stopping) > 2 - @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst + @test XGBoost.getnrounds(bst_early_stopping) > 2 + @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst # test the interface with an empty watchlist (no output) # this should trigger no early stopping rounds @@ -241,8 +241,8 @@ end eval_metric=["rmsle","rmse"], early_stopping_rounds = 2 # this should be ignored ) - - @test XGBoost.getnrounds(bst_empty_watchlist) == nrounds_bst + + @test XGBoost.getnrounds(bst_empty_watchlist) == nrounds_bst end From ec8d066e7d6b86c68a88b1e4ba86429a9b97b2ea Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Fri, 27 Oct 2023 12:12:38 +1100 Subject: [PATCH 09/15] Added additional info on data types for watchlist --- docs/src/index.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/src/index.md b/docs/src/index.md index d743874..7c3c387 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -188,6 +188,12 @@ update!(bst, (X, y), num_round=10) To help prevent overfitting to the training set, it is helpful to use a validation set to evaluate against to ensure that the XGBoost iterations continue to generalise outside training loss reduction. Early stopping provides a convenient way to automatically stop the boosting process if it's observed that the generalisation capability of the model does not improve for `k` rounds. +If there is more than one element in watchlist, by default the last element will be used. This makes it important to use an ordered data structure (OrderedDict) compared to a standard unordered dictionary; as you might not be guaranteed deterministic behaviour. There will be +a warning if you want to execute early stopping mechanism (`early_stopping_rounds > 0`) but have provided a watchlist with type `Dict` with +more than 1 element. + +Similarly, if there is more than one element in eval_metric, by default the last element will be used. + For example: ```julia From 0b8be979670782bfc21df88c648479f15a232e11 Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Fri, 27 Oct 2023 12:18:11 +1100 Subject: [PATCH 10/15] Annotated OrderedDict to be more obvious. --- docs/src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 7c3c387..5f6d190 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -188,7 +188,7 @@ update!(bst, (X, y), num_round=10) To help prevent overfitting to the training set, it is helpful to use a validation set to evaluate against to ensure that the XGBoost iterations continue to generalise outside training loss reduction. Early stopping provides a convenient way to automatically stop the boosting process if it's observed that the generalisation capability of the model does not improve for `k` rounds. -If there is more than one element in watchlist, by default the last element will be used. This makes it important to use an ordered data structure (OrderedDict) compared to a standard unordered dictionary; as you might not be guaranteed deterministic behaviour. There will be +If there is more than one element in watchlist, by default the last element will be used. This makes it important to use an ordered data structure (`OrderedDict`) compared to a standard unordered dictionary; as you might not be guaranteed deterministic behaviour. There will be a warning if you want to execute early stopping mechanism (`early_stopping_rounds > 0`) but have provided a watchlist with type `Dict` with more than 1 element. From 13e4b84d11699c2bfd2b9358a5c7593290029c15 Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Fri, 27 Oct 2023 15:17:32 +1100 Subject: [PATCH 11/15] Included using statement for OrderedCollection --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 0bd1c8c..8bc4b2c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using CUDA: has_cuda, cu import Term using Random, SparseArrays using Test +using OrderedCollections include("utils.jl") From be3236b6da4bfbd04e1ed0dbb3c74006c0370411 Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Thu, 9 Nov 2023 12:09:37 +1100 Subject: [PATCH 12/15] Moved log message parsing to update! instead of updateone --- src/booster.jl | 60 +++++++++++++++++++++----------------------------- 1 file changed, 25 insertions(+), 35 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index 5ed07ec..82f745b 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -371,16 +371,8 @@ function updateone!(b::Booster, Xy::DMatrix; update_feature_names::Bool=false, ) xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle) - # obtain the logs if watchlist is present (for early stopping and/or info) - if isempty(watchlist) - msg = nothing - else - msg = evaliter(b, watchlist, round_number) - @info msg - end - #isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) _maybe_update_feature_names!(b, Xy, update_feature_names) - b, msg + b end function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::AbstractVector{<:Real}; @@ -394,16 +386,8 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr g = convert(Vector{Cfloat}, g) h = convert(Vector{Cfloat}, h) xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g)) - # obtain the logs if watchlist is present (for early stopping and/or info) - if isempty(watchlist) - msg = nothing - else - msg = evaliter(b, watchlist, round_number) - @info msg - end - #isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) _maybe_update_feature_names!(b, Xy, update_feature_names) - b, msg + b end """ @@ -458,20 +442,26 @@ function update!(b::Booster, data, a...; for j ∈ 1:num_round round_number = getnrounds(b) + 1 - b, msg = updateone!(b, data, a...; round_number, watchlist, kw...) - if !isempty(watchlist) && early_stopping_rounds > 0 - score, dataset, metric = extract_metric_value(msg) - if (maximize && score > best_score || (!maximize && score < best_score)) - best_score = score - best_round = j - elseif j - best_round >= early_stopping_rounds - @info( - "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds." - ) - # add additional fields to record the best iteration - b.best_iteration = best_round - b.best_score = best_score - return b + updateone!(b, data, a...; round_number, watchlist, kw...) + + # Evaluate if watchlist is not empty + if !isempty(watchlist) + msg = evaliter(b, watchlist, round_number) + @info msg + if early_stopping_rounds > 0 + score, dataset, metric = extract_metric_value(msg) + if (maximize && score > best_score || (!maximize && score < best_score)) + best_score = score + best_round = j + elseif j - best_round >= early_stopping_rounds + @info( + "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds." + ) + # add additional fields to record the best iteration + b.best_iteration = best_round + b.best_score = best_score + return b + end end end end @@ -585,7 +575,7 @@ ŷ = predict(b, dvalid, ntree_limit = b.best_iteration) """ function xgboost(dm::DMatrix, a...; num_round::Integer=10, - watchlist::Any = Dict("train" => dm), + watchlist::AbstractDict = Dict("train" => dm), early_stopping_rounds::Integer=0, maximize=false, kw... @@ -597,12 +587,12 @@ function xgboost(dm::DMatrix, a...; # We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1 if isa(watchlist, Dict) if early_stopping_rounds > 0 && length(watchlist) > 1 - @warn "Early stopping rounds activated whilst watchlist has more than 1 element. Recommended to provide watchlist as an OrderedDict to ensure deterministic behaviour." + @error "You must supply an OrderedDict type for watchlist if early stopping rounds is enabled." end end if isempty(watchlist) && early_stopping_rounds > 0 - @warn "Early stopping is ignored as provided watchlist is empty." + @error "Watchlist must be supplied if early_stopping_rounds is enabled." end isempty(watchlist) || @info("XGBoost: starting training.") From a24258af6e58a765b0db2d2345341ec324480277 Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Thu, 9 Nov 2023 12:29:04 +1100 Subject: [PATCH 13/15] Updated documentation and tests. --- docs/src/index.md | 2 +- src/booster.jl | 4 +-- test/runtests.jl | 77 +++++++++++++++++++++++------------------------ 3 files changed, 41 insertions(+), 42 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 5f6d190..a284f4f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -188,7 +188,7 @@ update!(bst, (X, y), num_round=10) To help prevent overfitting to the training set, it is helpful to use a validation set to evaluate against to ensure that the XGBoost iterations continue to generalise outside training loss reduction. Early stopping provides a convenient way to automatically stop the boosting process if it's observed that the generalisation capability of the model does not improve for `k` rounds. -If there is more than one element in watchlist, by default the last element will be used. This makes it important to use an ordered data structure (`OrderedDict`) compared to a standard unordered dictionary; as you might not be guaranteed deterministic behaviour. There will be +If there is more than one element in watchlist, by default the last element will be used. In this case, you must use an ordered data structure (`OrderedDict`) compared to a standard unordered dictionary otherwise an exception will be generated. There will be a warning if you want to execute early stopping mechanism (`early_stopping_rounds > 0`) but have provided a watchlist with type `Dict` with more than 1 element. diff --git a/src/booster.jl b/src/booster.jl index 82f745b..f8b570c 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -587,12 +587,12 @@ function xgboost(dm::DMatrix, a...; # We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1 if isa(watchlist, Dict) if early_stopping_rounds > 0 && length(watchlist) > 1 - @error "You must supply an OrderedDict type for watchlist if early stopping rounds is enabled." + error("You must supply an OrderedDict type for watchlist if early stopping rounds is enabled.") end end if isempty(watchlist) && early_stopping_rounds > 0 - @error "Watchlist must be supplied if early_stopping_rounds is enabled." + error("Watchlist must be supplied if early_stopping_rounds is enabled.") end isempty(watchlist) || @info("XGBoost: starting training.") diff --git a/test/runtests.jl b/test/runtests.jl index 8bc4b2c..844b255 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -146,23 +146,30 @@ end objective="binary:logistic", eval_metric=["rmsle","rmse"] ) + + # test if it ran all the way till the end (baseline) + nrounds_bst = XGBoost.getnrounds(bst) + @test nrounds_bst == 30 + + let err = nothing + try + # Check to see that xgboost will error out when watchlist supplied is a dictionary with early_stopping_rounds enabled + bst_early_stopping = xgboost(dtrain, + num_round=30, + watchlist=watchlist, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) - bst_early_stopping = xgboost(dtrain, - num_round=30, - watchlist=watchlist, - η=1, - objective="binary:logistic", - eval_metric=["rmsle","rmse"], - early_stopping_rounds = 2 - ) - - nrounds_bst = XGBoost.getnrounds(bst) - nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping) - # Check to see that running with early stopping results in less than or equal rounds - @test nrounds_bst_early_stopping <= nrounds_bst + nrounds_bst = XGBoost.getnrounds(bst) + nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping) + catch err + end - # Check number of rounds > early stopping rounds - @test nrounds_bst_early_stopping > 2 + @test err isa Exception + end # test the early stopping rounds interface with an OrderedDict data type in the watchlist watchlist_ordered = OrderedDict("train"=>dtrain, "eval"=>dtest) @@ -176,6 +183,8 @@ end early_stopping_rounds = 2 ) + + @test XGBoost.getnrounds(bst_early_stopping) > 2 @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst @@ -220,31 +229,21 @@ end # ensure that the results are the same (as numerically possible) with the best round @test abs(bst_early_stopping.best_score - calc_metric) < 1e-9 - # test the interface with no watchlist provided (defaults to the training dataset) - bst_early_stopping = xgboost(dtrain, - num_round=30, - η=1, - objective="binary:logistic", - eval_metric=["rmsle","rmse"], - early_stopping_rounds = 2 - ) - - @test XGBoost.getnrounds(bst_early_stopping) > 2 - @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst - - # test the interface with an empty watchlist (no output) - # this should trigger no early stopping rounds - bst_empty_watchlist = xgboost(dtrain, - num_round=30, - η=1, - watchlist = Dict(), - objective="binary:logistic", - eval_metric=["rmsle","rmse"], - early_stopping_rounds = 2 # this should be ignored - ) - - @test XGBoost.getnrounds(bst_empty_watchlist) == nrounds_bst + # Test the interface with no watchlist provided (it'll default to training watchlist) + let err = nothing + try + bst_early_stopping = xgboost(dtrain, + num_round=30, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) + catch err + end + @test !(err isa Exception) + end end From 61abfeb695add9cc150c9b663d42012b2fa686ec Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Thu, 9 Nov 2023 12:32:56 +1100 Subject: [PATCH 14/15] Altered the XGBoost method definition to reflect exception states for early stopping rounds and watchlist. --- src/booster.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index f8b570c..7f4dde0 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -533,13 +533,15 @@ This is essentially an alias for constructing a [`Booster`](@ref) with `data` an followed by [`update!`](@ref) for `nrounds`. `watchlist` is a dict the keys of which are strings giving the name of the data to watch -and the values of which are [`DMatrix`](@ref) objects containing the data. It is critical to use an OrderedDict -when utilising early_stopping_rounds to ensure XGBoost uses the correct and intended dataset to perform early stop. +and the values of which are [`DMatrix`](@ref) objects containing the data. It is mandatory to use an OrderedDict +when utilising early_stopping_rounds and there is more than 1 element in watchlist to ensure XGBoost uses the +correct and intended dataset to perform early stop. `early_stopping_rounds` activates early stopping if set to > 0. Validation metric needs to improve at least once in every k rounds. If `watchlist` is not explicitly provided, it will use the training dataset to evaluate the stopping criterion. Otherwise, it will use the last data element in `watchlist` and the -last metric in `eval_metric` (if more than one). Note that early stopping is ignored if `watchlist` is empty. +last metric in `eval_metric` (if more than one). Note that `watchlist` cannot be empty if +`early_stopping_rounds` is enabled. `maximize` If early_stopping_rounds is set, then this parameter must be set as well. When it is false, it means the smaller the evaluation score the better. When set to true, @@ -587,7 +589,7 @@ function xgboost(dm::DMatrix, a...; # We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1 if isa(watchlist, Dict) if early_stopping_rounds > 0 && length(watchlist) > 1 - error("You must supply an OrderedDict type for watchlist if early stopping rounds is enabled.") + error("You must supply an OrderedDict type for watchlist if early stopping rounds is enabled and there is more than one element in watchlist.") end end From b23bffd028b577377a53ccf60b406f2ae1d9139f Mon Sep 17 00:00:00 2001 From: Wilan Wong Date: Thu, 9 Nov 2023 12:44:14 +1100 Subject: [PATCH 15/15] Created exception if extract_metric_value could not find a match when parsing XGBoost logs. --- src/booster.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/booster.jl b/src/booster.jl index 7f4dde0..32f94bb 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -518,10 +518,10 @@ function extract_metric_value(msg, dataset=nothing, metric=nothing) if match_result != nothing parsed_value = parse(Float64, match_result.captures[1]) return parsed_value, dataset, metric - else - @warn "No match found for pattern: $dataset-$metric in message: $msg" - return nothing end + + # there was no match result - should error out + error("No match found for pattern: $dataset-$metric in message: $msg") end """