diff --git a/src/booster.jl b/src/booster.jl index 2769b89..96d1814 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -366,8 +366,10 @@ function updateone!(b::Booster, Xy::DMatrix; update_feature_names::Bool=false, ) xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle) + isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) + @info msg _maybe_update_feature_names!(b, Xy, update_feature_names) - b + b, msg end function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::AbstractVector{<:Real}; @@ -381,9 +383,10 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr g = convert(Vector{Cfloat}, g) h = convert(Vector{Cfloat}, h) xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g)) - isempty(watchlist) || logeval(b, watchlist, round_number) + isempty(watchlist) || (msg = evaliter(b, watchlist, round_number)) + @info msg _maybe_update_feature_names!(b, Xy, update_feature_names) - b + b, msg end """ @@ -430,39 +433,88 @@ function update!(b::Booster, data, a...; ) if !isempty(watchlist) && early_stopping_rounds > 0 - early_stopping_set = collect(Iterators.map(string, keys(watchlist)))[end] - println("Will train until there has been no improvement in $early_stopping_rounds rounds.\n") + @info("Will train until there has been no improvement in $early_stopping_rounds rounds.\n") best_round = 0 best_score = maximize ? -Inf : Inf end for j ∈ 1:num_round round_number = getnrounds(b) + 1 - updateone!(b, data, a...; round_number, kw...) - if !isempty(watchlist) - msg = evaliter(b, watchlist, j) - @info(msg) - if early_stopping_rounds > 0 - split_msg = split(msg, r"\s+|:") - metric_name_idx = findfirst(occursin.(early_stopping_set, split_msg)) - score = parse(Float64, split_msg[metric_name_idx + 1]) - - if (maximize && score > best_score || (!maximize && score < best_score)) - best_score = score - best_round = j - elseif j - best_round >= early_stopping_rounds - watchlist_metric = (split_msg[metric_name_idx]) - @info( - "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\t$(watchlist_metric) has not improved in $early_stopping_rounds rounds." - ) - return (b) - end + b, msg = updateone!(b, data, a...; round_number, kw...) + if !isempty(watchlist) && early_stopping_rounds > 0 + score, dataset, metric = extract_metric_value(msg) + if (maximize && score > best_score || (!maximize && score < best_score)) + best_score = score + best_round = j + elseif j - best_round >= early_stopping_rounds + @info( + "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds." + ) + return (b) end end end b end + + +""" + extract_metric_value(msg, dataset=nothing, metric=nothing) + +Extracts a numeric value from a message based on the specified dataset and metric. +If dataset or metric is not provided, the function will automatically find the last +mentioned dataset or metric in the message. + +# Arguments +- `msg::AbstractString`: The message containing the numeric values. +- `dataset::Union{AbstractString, Nothing}`: The dataset to extract values for (default: `nothing`). +- `metric::Union{AbstractString, Nothing}`: The metric to extract values for (default: `nothing`). + +# Returns +- Returns the parsed Float64 value if a match is found, otherwise returns `nothing`. + +# Examples +```julia +msg = "train-rmsle:0.09516384803222511 train-rmse:0.12458323318968342 eval-rmsle:0.09311178520817574 eval-rmse:0.12088154560829874" + +# Without specifying dataset and metric +value_without_params = extract_metric_value(msg) +println(value_without_params) # Output: (0.09311178520817574, "eval", "rmsle") + +# With specifying dataset and metric +value_with_params = extract_metric_value(msg, "train", "rmsle") +println(value_with_params) # Output: (0.0951638480322251, "train", "rmsle") +""" + +function extract_metric_value(msg, dataset=nothing, metric=nothing) + if isnothing(dataset) + # Find the last mentioned dataset + datasets = Set([m.match for m in eachmatch(r"\w+(?=-)", msg)]) + dataset = last(collect(datasets)) + end + + if isnothing(metric) + # Find the first mentioned metric + metrics = Set([m.match for m in eachmatch(r"(?<=-)\w+", msg)]) + metric = last(collect(metrics)) + end + + pattern = Regex("$dataset-$metric:([\\d.]+)") + + match_result = match(pattern, msg) + + if match_result != nothing + parsed_value = parse(Float64, match_result.captures[1]) + return parsed_value, dataset, metric + else + @warn "No match found for pattern: $dataset-$metric in message: $msg" + return nothing + end +end + + + """ xgboost(data; num_round=10, watchlist=Dict(), kw...) xgboost(data, ℓ′, ℓ″; kw...)