Merge pull request #215 from JuliaAI/compact-evaluation-objects
Create option to write `CompactPerformanceEvaluation` objects to history
ablaom authored May 6, 2024
2 parents c597f27 + 5143292 commit 8f0c036
Showing 3 changed files with 86 additions and 49 deletions.
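For context, here is a minimal sketch of the new option in use. It is illustrative only: the `DecisionTreeRegressor` (which needs DecisionTree.jl's MLJ interface installed) and the synthetic data are stand-ins for any `Supervised` model and tuning setup.

```julia
using MLJ  # re-exports TunedModel, Grid, CV, range, mae, machine, fit!, report, ...

# placeholder data and model
X, y = make_regression(100, 2)
Tree = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0
tree = Tree()

self_tuning = TunedModel(
    model=tree,
    tuning=Grid(),
    resampling=CV(nfolds=3),
    range=range(tree, :max_depth, values=1:5),
    measure=mae,
)   # `compact_history=true` is the new default

mach = machine(self_tuning, X, y)
fit!(mach, verbosity=0)

# each history entry now stores a memory-saving CompactPerformanceEvaluation:
first(report(mach).history).evaluation
```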
2 changes: 1 addition & 1 deletion Project.toml
@@ -18,7 +18,7 @@ StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
ComputationalResources = "0.3"
Distributions = "0.22,0.23,0.24, 0.25"
LatinHypercubeSampling = "1.7.2"
MLJBase = "1"
MLJBase = "1.3"
ProgressMeter = "1.7.1"
RecipesBase = "0.8,0.9,1"
StatisticalMeasuresBase = "0.1.1"
97 changes: 60 additions & 37 deletions src/tuned_models.jl
@@ -50,6 +50,7 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
acceleration_resampling::AbstractResource
check_measure::Bool
cache::Bool
compact_history::Bool
end

mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Probabilistic
@@ -69,6 +70,7 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
acceleration_resampling::AbstractResource
check_measure::Bool
cache::Bool
compact_history::Bool
end

const EitherTunedModel{T,M} =
@@ -176,6 +178,15 @@ key | value
plus other key/value pairs specific to the `tuning` strategy.
Each element of `history` is a property-accessible object with these properties:
key | value
--------------------|--------------------------------------------------
`measure` | vector of measures (metrics)
`measurement` | vector of measurements, one per measure
`per_fold` | vector of vectors of unaggregated per-fold measurements
`evaluation` | full `PerformanceEvaluation`/`CompactPerformanceEvaluation` object
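For illustration, these properties might be read off like this (a sketch only; `mach` is assumed to be a machine wrapping a fitted `TunedModel`):

```julia
entry = first(report(mach).history)

entry.measure      # e.g. [mae]
entry.measurement  # one aggregated value per measure
entry.per_fold     # unaggregated per-fold values, one vector per measure
entry.evaluation   # PerformanceEvaluation or CompactPerformanceEvaluation
```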
### Complete list of keyword options
- `model`: `Supervised` model prototype that is cloned and mutated to
@@ -240,27 +251,35 @@ plus other key/value pairs specific to the `tuning` strategy.
user-supplied data; set to `false` to conserve memory. Speed gains
likely limited to the case `resampling isa Holdout`.
- `compact_history=true`: whether to write [`CompactPerformanceEvaluation`](@ref) or
regular [`PerformanceEvaluation`](@ref) objects to the history (accessed via the
`:evaluation` key); the compact form excludes some fields to conserve memory.
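As a further sketch, continuing the placeholder setup from the example near the top of this page (`tree`, `X`, `y` assumed), passing `compact_history=false` keeps the full evaluation objects in the history:

```julia
import MLJBase

self_tuning_full = TunedModel(
    model=tree,
    tuning=Grid(),
    resampling=CV(nfolds=3),
    range=range(tree, :max_depth, values=1:5),
    measure=mae,
    compact_history=false,   # retain full PerformanceEvaluation objects
)
mach_full = machine(self_tuning_full, X, y)
fit!(mach_full, verbosity=0)

first(report(mach_full).history).evaluation isa MLJBase.PerformanceEvaluation  # true
```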
"""
function TunedModel(args...; model=nothing,
models=nothing,
tuning=nothing,
resampling=MLJBase.Holdout(),
measures=nothing,
measure=measures,
weights=nothing,
class_weights=nothing,
operations=nothing,
operation=operations,
ranges=nothing,
range=ranges,
selection_heuristic=NaiveSelection(),
train_best=true,
repeats=1,
n=nothing,
acceleration=default_resource(),
acceleration_resampling=CPU1(),
check_measure=true,
cache=true)
function TunedModel(
args...;
model=nothing,
models=nothing,
tuning=nothing,
resampling=MLJBase.Holdout(),
measures=nothing,
measure=measures,
weights=nothing,
class_weights=nothing,
operations=nothing,
operation=operations,
ranges=nothing,
range=ranges,
selection_heuristic=NaiveSelection(),
train_best=true,
repeats=1,
n=nothing,
acceleration=default_resource(),
acceleration_resampling=CPU1(),
check_measure=true,
cache=true,
compact_history=true,
)

# user can specify model as argument instead of kwarg:
length(args) < 2 || throw(ERR_TOO_MANY_ARGUMENTS)
@@ -339,7 +358,8 @@ function TunedModel(args...; model=nothing,
acceleration,
acceleration_resampling,
check_measure,
cache
cache,
compact_history,
)

if M <: DeterministicTypes
@@ -582,9 +602,10 @@ function assemble_events!(metamodels,
check_measure = resampling_machine.model.check_measure,
repeats = resampling_machine.model.repeats,
acceleration = resampling_machine.model.acceleration,
cache = resampling_machine.model.cache),
resampling_machine.args...; cache=false) for
_ in 2:length(partitions)]...]
cache = resampling_machine.model.cache,
compact = resampling_machine.model.compact
), resampling_machine.args...; cache=false) for
_ in 2:length(partitions)]...]

@sync for (i, parts) in enumerate(partitions)
Threads.@spawn begin
@@ -736,21 +757,23 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},

# instantiate resampler (`model` to be replaced with mutated
# clones during iteration below):
resampler = Resampler(model=model,
resampling = deepcopy(tuned_model.resampling),
measure = tuned_model.measure,
weights = tuned_model.weights,
class_weights = tuned_model.class_weights,
operation = tuned_model.operation,
check_measure = tuned_model.check_measure,
repeats = tuned_model.repeats,
acceleration = tuned_model.acceleration_resampling,
cache = tuned_model.cache)
resampler = Resampler(
model=model,
resampling = deepcopy(tuned_model.resampling),
measure = tuned_model.measure,
weights = tuned_model.weights,
class_weights = tuned_model.class_weights,
operation = tuned_model.operation,
check_measure = tuned_model.check_measure,
repeats = tuned_model.repeats,
acceleration = tuned_model.acceleration_resampling,
cache = tuned_model.cache,
compact = tuned_model.compact_history,
)
resampling_machine = machine(resampler, data...; cache=false)
history, state = build!(nothing, n, tuning, model, model_buffer, state,
verbosity, acceleration, resampling_machine)


return finalize(
tuned_model,
model_buffer,
@@ -867,9 +890,9 @@ function MLJBase.reports_feature_importances(model::EitherTunedModel)
end # This is needed in some cases (e.g. tuning a `Pipeline`)

function MLJBase.feature_importances(::EitherTunedModel, fitresult, report)
# fitresult here is a machine created using the best_model obtained
# fitresult here is a machine created using the best_model obtained
# from the tuning process.
# The line below will return `nothing` when the model being tuned doesn't
# The line below will return `nothing` when the model being tuned doesn't
# support feature_importances.
return MLJBase.feature_importances(fitresult)
end
36 changes: 25 additions & 11 deletions test/tuned_models.jl
@@ -13,7 +13,7 @@ using Random
Random.seed!(1234*myid())
using .TestUtilities

begin
begin
N = 30
x1 = rand(N);
x2 = rand(N);
@@ -157,14 +157,14 @@ end

@testset_accelerated "Feature Importances" accel begin
# the DecisionTreeClassifier in /test/_models/ supports feature importances.
tm0 = TunedModel(
model = trees[1],
measure = rms,
tuning = Grid(),
resampling = CV(nfolds = 5),
range = range(
trees[1], :max_depth, values = 1:10
)
tm0 = TunedModel(
model = trees[1],
measure = rms,
tuning = Grid(),
resampling = CV(nfolds = 5),
range = range(
trees[1], :max_depth, values = 1:10
)
)
@test reports_feature_importances(typeof(tm0))
tm = TunedModel(
@@ -435,7 +435,7 @@ end
model = DecisionTreeClassifier()
tmodel = TunedModel(models=[model,])
mach = machine(tmodel, X, y)
@test mach isa Machine{<:Any,false}
@test !MLJBase.caches_data(mach)
fit!(mach, verbosity=-1)
@test !isdefined(mach, :data)
MLJBase.Tables.istable(mach.cache[end].fitresult.machine.data[1])
@@ -490,7 +490,7 @@ end
@test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2)
end

@testset_accelerated "full evaluation object" accel begin
@testset_accelerated "evaluation object" accel begin
X, y = make_regression(100, 2)
dcr = DeterministicConstantRegressor()

@@ -504,10 +504,24 @@
fit!(homach, verbosity=0);
horep = report(homach)
evaluations = getproperty.(horep.history, :evaluation)
@test first(evaluations) isa MLJBase.CompactPerformanceEvaluation
measurements = getproperty.(evaluations, :measurement)
models = getproperty.(evaluations, :model)
@test all(==(measurements[1]), measurements)
@test all(==(dcr), models)

homodel = TunedModel(
models=fill(dcr, 10),
resampling=Holdout(rng=StableRNG(1234)),
acceleration_resampling=accel,
measure=mae,
compact_history=false,
)
homach = machine(homodel, X, y)
fit!(homach, verbosity=0);
horep = report(homach)
evaluations = getproperty.(horep.history, :evaluation)
@test first(evaluations) isa MLJBase.PerformanceEvaluation
end

true
