Skip to content

Commit

Permalink
Merge pull request #2 from jahan-ai/early_stopping_rounds
Browse files Browse the repository at this point in the history
modifications pending PR to dmlc/XGBoost
  • Loading branch information
david-sun-1 authored Oct 11, 2023
2 parents 7ed00c4 + 5c9baba commit 1e9040b
Showing 1 changed file with 76 additions and 24 deletions.
100 changes: 76 additions & 24 deletions src/booster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,10 @@ function updateone!(b::Booster, Xy::DMatrix;
update_feature_names::Bool=false,
)
xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle)
isempty(watchlist) || (msg = evaliter(b, watchlist, round_number))
@info msg
_maybe_update_feature_names!(b, Xy, update_feature_names)
b
b, msg
end

function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::AbstractVector{<:Real};
Expand All @@ -381,9 +383,10 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr
g = convert(Vector{Cfloat}, g)
h = convert(Vector{Cfloat}, h)
xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g))
isempty(watchlist) || logeval(b, watchlist, round_number)
isempty(watchlist) || (msg = evaliter(b, watchlist, round_number))
@info msg
_maybe_update_feature_names!(b, Xy, update_feature_names)
b
b, msg
end

"""
Expand Down Expand Up @@ -430,39 +433,88 @@ function update!(b::Booster, data, a...;
)

if !isempty(watchlist) && early_stopping_rounds > 0
early_stopping_set = collect(Iterators.map(string, keys(watchlist)))[end]
println("Will train until there has been no improvement in $early_stopping_rounds rounds.\n")
@info("Will train until there has been no improvement in $early_stopping_rounds rounds.\n")
best_round = 0
best_score = maximize ? -Inf : Inf
end

for j 1:num_round
round_number = getnrounds(b) + 1
updateone!(b, data, a...; round_number, kw...)
if !isempty(watchlist)
msg = evaliter(b, watchlist, j)
@info(msg)
if early_stopping_rounds > 0
split_msg = split(msg, r"\s+|:")
metric_name_idx = findfirst(occursin.(early_stopping_set, split_msg))
score = parse(Float64, split_msg[metric_name_idx + 1])

if (maximize && score > best_score || (!maximize && score < best_score))
best_score = score
best_round = j
elseif j - best_round >= early_stopping_rounds
watchlist_metric = (split_msg[metric_name_idx])
@info(
"Xgboost: Stopping. \n\tBest iteration: $best_round. \n\t$(watchlist_metric) has not improved in $early_stopping_rounds rounds."
)
return (b)
end
b, msg = updateone!(b, data, a...; round_number, kw...)
if !isempty(watchlist) && early_stopping_rounds > 0
score, dataset, metric = extract_metric_value(msg)
if (maximize && score > best_score || (!maximize && score < best_score))
best_score = score
best_round = j
elseif j - best_round >= early_stopping_rounds
@info(
"Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds."
)
return (b)
end
end
end
b
end



"""
    extract_metric_value(msg, dataset=nothing, metric=nothing)

Extract one evaluation score from an evaluation message.

`msg` is expected to be a whitespace-separated list of `dataset-metric:value` entries,
e.g. `"[1]\\ttrain-rmse:0.12\\teval-rmse:0.13"`. Tokens that do not have this shape (such
as the leading `[1]` round marker) are ignored. Each entry is split at its *last* `:`
(so the value may be negative or in scientific notation) and its key is split at the
*first* `-` (so metric names that themselves contain hyphens, such as
`poisson-nloglik`, stay intact).

The entry returned is the last one in `msg` that matches every filter that was
provided. With no filters this is simply the last entry in the message.

# Arguments
- `msg::AbstractString`: the message containing the numeric values.
- `dataset::Union{AbstractString, Nothing}`: dataset name to match exactly
  (default `nothing`: the last mentioned dataset).
- `metric::Union{AbstractString, Nothing}`: metric name to match exactly
  (default `nothing`: the last mentioned metric).

# Returns
- A tuple `(value::Float64, dataset, metric)` when a matching entry is found.
- `nothing` (after emitting a warning) when no entry matches.

# Examples
```julia
msg = "train-rmsle:0.09516384803222511 train-rmse:0.12458323318968342 eval-rmsle:0.09311178520817574 eval-rmse:0.12088154560829874"

# Without specifying dataset and metric: the last entry wins
extract_metric_value(msg)                    # (0.12088154560829874, "eval", "rmse")

# With an explicit dataset and metric
extract_metric_value(msg, "train", "rmsle")  # (0.09516384803222511, "train", "rmsle")
```
"""
function extract_metric_value(msg, dataset=nothing, metric=nothing)
    # Walk the entries back-to-front so the first hit is the last one mentioned.
    for token in Iterators.reverse(split(msg))
        entry = _parse_eval_entry(token)
        entry === nothing && continue
        entry_dataset, entry_metric, value = entry
        if (isnothing(dataset) || entry_dataset == dataset) &&
           (isnothing(metric) || entry_metric == metric)
            # Hand back the caller's own names when they were supplied.
            return value, something(dataset, entry_dataset), something(metric, entry_metric)
        end
    end

    dataset_label = something(dataset, "*")
    metric_label = something(metric, "*")
    @warn "No match found for pattern: $dataset_label-$metric_label in message: $msg"
    return nothing
end

# Parse a single `dataset-metric:value` token into `(dataset, metric, value)`.
# Returns `nothing` for tokens that do not have that shape or whose value is not a number.
function _parse_eval_entry(token::AbstractString)
    # Split at the LAST ':' so the numeric value is isolated from the key.
    parts = rsplit(token, ':'; limit=2)
    length(parts) == 2 || return nothing
    key, raw_value = parts

    value = tryparse(Float64, raw_value)
    value === nothing && return nothing

    # Split at the FIRST '-' so hyphenated metric names stay in one piece.
    names = split(key, '-'; limit=2)
    length(names) == 2 || return nothing
    entry_dataset, entry_metric = names
    (isempty(entry_dataset) || isempty(entry_metric)) && return nothing

    return String(entry_dataset), String(entry_metric), value
end



"""
xgboost(data; num_round=10, watchlist=Dict(), kw...)
xgboost(data, ℓ′, ℓ″; kw...)
Expand Down

0 comments on commit 1e9040b

Please sign in to comment.