From 91b9176626d829d2d5f6f3e394062cd67bb3a643 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Tue, 21 Nov 2023 11:43:14 -0500 Subject: [PATCH] allow watchlist to be NamedTuple --- Project.toml | 2 +- src/XGBoost.jl | 1 - src/booster.jl | 6 +++--- test/runtests.jl | 9 +++++++++ 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 8f38b9b..7daa4a8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "XGBoost" uuid = "009559a3-9522-5dbb-924b-0b6ed2b22bb9" -version = "2.5.0" +version = "2.5.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/XGBoost.jl b/src/XGBoost.jl index 856cecd..e1756da 100644 --- a/src/XGBoost.jl +++ b/src/XGBoost.jl @@ -23,7 +23,6 @@ include("Lib.jl") using .Lib using .Lib: DMatrixHandle, BoosterHandle - const LOG_LEVEL_REGEX = r"\[.*\] (\D*): " function xgblog(s::Cstring) diff --git a/src/booster.jl b/src/booster.jl index e4df32c..ef24f93 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -428,7 +428,7 @@ for custom loss. """ function update!(b::Booster, data, a...; num_round::Integer=1, - watchlist::Any = Dict("train" => data), + watchlist=Dict("train" => data), early_stopping_rounds::Integer=0, maximize=false, kw..., @@ -578,7 +578,7 @@ ŷ = predict(b, dvalid, ntree_limit = b.best_iteration) """ function xgboost(dm::DMatrix, a...; num_round::Integer=10, - watchlist::AbstractDict = Dict("train" => dm), + watchlist=Dict("train" => dm), early_stopping_rounds::Integer=0, maximize=false, kw... @@ -590,7 +590,7 @@ function xgboost(dm::DMatrix, a...; # We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1 if isa(watchlist, Dict) if early_stopping_rounds > 0 && length(watchlist) > 1 - error("You must supply an OrderedDict type for watchlist if early stopping rounds is enabled and there is more than one element in watchlist.") + error("You must supply an OrderedDict or NamedTuple type for watchlist if early stopping rounds is enabled and there is more than one element in watchlist.") end end diff --git a/test/runtests.jl b/test/runtests.jl index 844b255..aea0a9e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -183,7 +183,16 @@ end early_stopping_rounds = 2 ) + watchlist_nt = (train=dtrain, eval=dtest) + bst_early_stopping = xgboost(dtrain, + num_round=30, + watchlist=watchlist_nt, + η=1, + objective="binary:logistic", + eval_metric=["rmsle","rmse"], + early_stopping_rounds = 2 + ) @test XGBoost.getnrounds(bst_early_stopping) > 2 @test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst