diff --git a/docs/src/index.md b/docs/src/index.md index 5f6d190..a284f4f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -188,7 +188,7 @@ update!(bst, (X, y), num_round=10) To help prevent overfitting to the training set, it is helpful to use a validation set to evaluate against to ensure that the XGBoost iterations continue to generalise outside training loss reduction. Early stopping provides a convenient way to automatically stop the boosting process if it's observed that the generalisation capability of the model does not improve for `k` rounds. -If there is more than one element in watchlist, by default the last element will be used. This makes it important to use an ordered data structure (`OrderedDict`) compared to a standard unordered dictionary; as you might not be guaranteed deterministic behaviour. There will be +If there is more than one element in watchlist, by default the last element will be used. In this case, you must use an ordered data structure (`OrderedDict`) rather than a standard unordered dictionary; otherwise an exception will be thrown. There will be a warning if you want to execute early stopping mechanism (`early_stopping_rounds > 0`) but have provided a watchlist with type `Dict` with more than 1 element. 
diff --git a/src/booster.jl b/src/booster.jl index 5ed07ec..32f94bb 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -371,16 +371,8 @@ function updateone!(b::Booster, Xy::DMatrix; update_feature_names::Bool=false, ) xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle) - # obtain the logs if watchlist is present (for early stopping and/or info) - if isempty(watchlist) - msg = nothing - else - msg = evaliter(b, watchlist, round_number) - @info msg - end - #isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) _maybe_update_feature_names!(b, Xy, update_feature_names) - b, msg + b end function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::AbstractVector{<:Real}; @@ -394,16 +386,8 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr g = convert(Vector{Cfloat}, g) h = convert(Vector{Cfloat}, h) xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g)) - # obtain the logs if watchlist is present (for early stopping and/or info) - if isempty(watchlist) - msg = nothing - else - msg = evaliter(b, watchlist, round_number) - @info msg - end - #isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) _maybe_update_feature_names!(b, Xy, update_feature_names) - b, msg + b end """ @@ -458,20 +442,26 @@ function update!(b::Booster, data, a...; for j ∈ 1:num_round round_number = getnrounds(b) + 1 - b, msg = updateone!(b, data, a...; round_number, watchlist, kw...) - if !isempty(watchlist) && early_stopping_rounds > 0 - score, dataset, metric = extract_metric_value(msg) - if (maximize && score > best_score || (!maximize && score < best_score)) - best_score = score - best_round = j - elseif j - best_round >= early_stopping_rounds - @info( - "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds." 
- ) - # add additional fields to record the best iteration - b.best_iteration = best_round - b.best_score = best_score - return b + updateone!(b, data, a...; round_number, watchlist, kw...) + + # Evaluate if watchlist is not empty + if !isempty(watchlist) + msg = evaliter(b, watchlist, round_number) + @info msg + if early_stopping_rounds > 0 + score, dataset, metric = extract_metric_value(msg) + if (maximize && score > best_score || (!maximize && score < best_score)) + best_score = score + best_round = j + elseif j - best_round >= early_stopping_rounds + @info( + "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds." + ) + # add additional fields to record the best iteration + b.best_iteration = best_round + b.best_score = best_score + return b + end end end end @@ -528,10 +518,10 @@ function extract_metric_value(msg, dataset=nothing, metric=nothing) if match_result != nothing parsed_value = parse(Float64, match_result.captures[1]) return parsed_value, dataset, metric - else - @warn "No match found for pattern: $dataset-$metric in message: $msg" - return nothing end + + # there was no match result - should error out + error("No match found for pattern: $dataset-$metric in message: $msg") end """ @@ -543,13 +533,15 @@ This is essentially an alias for constructing a [`Booster`](@ref) with `data` an followed by [`update!`](@ref) for `nrounds`. `watchlist` is a dict the keys of which are strings giving the name of the data to watch -and the values of which are [`DMatrix`](@ref) objects containing the data. It is critical to use an OrderedDict -when utilising early_stopping_rounds to ensure XGBoost uses the correct and intended dataset to perform early stop. +and the values of which are [`DMatrix`](@ref) objects containing the data. 
It is mandatory to use an OrderedDict +when utilising early_stopping_rounds and there is more than 1 element in watchlist to ensure XGBoost uses the +correct and intended dataset to perform early stop. `early_stopping_rounds` activates early stopping if set to > 0. Validation metric needs to improve at least once in every k rounds. If `watchlist` is not explicitly provided, it will use the training dataset to evaluate the stopping criterion. Otherwise, it will use the last data element in `watchlist` and the -last metric in `eval_metric` (if more than one). Note that early stopping is ignored if `watchlist` is empty. +last metric in `eval_metric` (if more than one). Note that `watchlist` cannot be empty if +`early_stopping_rounds` is enabled. `maximize` If early_stopping_rounds is set, then this parameter must be set as well. When it is false, it means the smaller the evaluation score the better. When set to true, @@ -585,7 +577,7 @@ ŷ = predict(b, dvalid, ntree_limit = b.best_iteration) """ function xgboost(dm::DMatrix, a...; num_round::Integer=10, - watchlist::Any = Dict("train" => dm), + watchlist::AbstractDict = Dict("train" => dm), early_stopping_rounds::Integer=0, maximize=false, kw... @@ -597,12 +589,12 @@ function xgboost(dm::DMatrix, a...; # We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1 if isa(watchlist, Dict) if early_stopping_rounds > 0 && length(watchlist) > 1 - @warn "Early stopping rounds activated whilst watchlist has more than 1 element. Recommended to provide watchlist as an OrderedDict to ensure deterministic behaviour." + error("You must supply an OrderedDict type for watchlist if early stopping rounds is enabled and there is more than one element in watchlist.") end end if isempty(watchlist) && early_stopping_rounds > 0 - @warn "Early stopping is ignored as provided watchlist is empty." 
+ error("Watchlist must be supplied if early_stopping_rounds is enabled.") end isempty(watchlist) || @info("XGBoost: starting training.") diff --git a/test/runtests.jl b/test/runtests.jl index 8bc4b2c..844b255 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -146,23 +146,30 @@ end objective="binary:logistic", eval_metric=["rmsle","rmse"] ) + + # test if it ran all the way till the end (baseline) + nrounds_bst = XGBoost.getnrounds(bst) + @test nrounds_bst == 30 + + let err = nothing + try + # Check to see that xgboost will error out when watchlist supplied is a dictionary with early_stopping_rounds enabled + bst_early_stopping = xgboost(dtrain, + num_round=30, + watchlist=watchlist, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) - bst_early_stopping = xgboost(dtrain, - num_round=30, - watchlist=watchlist, - η=1, - objective="binary:logistic", - eval_metric=["rmsle","rmse"], - early_stopping_rounds = 2 - ) - - nrounds_bst = XGBoost.getnrounds(bst) - nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping) - # Check to see that running with early stopping results in less than or equal rounds - @test nrounds_bst_early_stopping <= nrounds_bst + nrounds_bst = XGBoost.getnrounds(bst) + nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping) + catch err + end - # Check number of rounds > early stopping rounds - @test nrounds_bst_early_stopping > 2 + @test err isa Exception + end # test the early stopping rounds interface with an OrderedDict data type in the watchlist watchlist_ordered = OrderedDict("train"=>dtrain, "eval"=>dtest) @@ -176,6 +183,8 @@ end early_stopping_rounds = 2 ) + + @test XGBoost.getnrounds(bst_early_stopping) > 2 @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst @@ -220,31 +229,21 @@ end # ensure that the results are the same (as numerically possible) with the best round @test abs(bst_early_stopping.best_score - calc_metric) < 1e-9 - # test the interface with 
no watchlist provided (defaults to the training dataset) - bst_early_stopping = xgboost(dtrain, - num_round=30, - η=1, - objective="binary:logistic", - eval_metric=["rmsle","rmse"], - early_stopping_rounds = 2 - ) - - @test XGBoost.getnrounds(bst_early_stopping) > 2 - @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst - - # test the interface with an empty watchlist (no output) - # this should trigger no early stopping rounds - bst_empty_watchlist = xgboost(dtrain, - num_round=30, - η=1, - watchlist = Dict(), - objective="binary:logistic", - eval_metric=["rmsle","rmse"], - early_stopping_rounds = 2 # this should be ignored - ) - - @test XGBoost.getnrounds(bst_empty_watchlist) == nrounds_bst + # Test the interface with no watchlist provided (it'll default to training watchlist) + let err = nothing + try + bst_early_stopping = xgboost(dtrain, + num_round=30, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) + catch err + end + @test !(err isa Exception) + end end