Skip to content

Commit

Permalink
Merge pull request #201 from JuliaAI/explicit-better-checks
Browse files Browse the repository at this point in the history
Add prediction type check for Explicit strategy
  • Loading branch information
ablaom authored Jan 23, 2024
2 parents f5256c5 + 53700b3 commit da15fb3
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJTuning"
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.8.0"
version = "0.8.1"

[deps]
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Expand Down
30 changes: 24 additions & 6 deletions src/strategies/explicit.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
const WARN_INCONSISTENT_PREDICTION_TYPE =
"Not all models to be evaluated have the same prediction type, and this may "*
"cause problems for some measures. For example, a probabilistic metric "*
"like `log_loss` cannot be applied to a model making point (deterministic) "*
"predictions. Inspect the prediction type with "*
"`prediction_type(model)`. "

mutable struct Explicit <: TuningStrategy end

struct ExplicitState{R, N}
range::R # a model-generating iterator
next::N # to hold output of `iterate(range)`
next::N # to hold output of `iterate(range)`
prediction_type::Symbol
user_warned::Bool
end

ExplictState(r::R, n::N) where {R,N} = ExplicitState{R, Union{Nothing, N}}(r, n)

function MLJTuning.setup(tuning::Explicit, model, range, n, verbosity)
next = iterate(range)
return ExplicitState(range, next)
return ExplicitState(range, next, MLJBase.prediction_type(model), false)
end

# models! returns as many models as possible but no more than `n_remaining`:
Expand All @@ -20,11 +27,21 @@ function MLJTuning.models(tuning::Explicit,
n_remaining,
verbosity)

range, next = state.range, state.next
range, next, prediction_type, user_warned =
state.range, state.next, state.prediction_type, state.user_warned

function check(m)
if !user_warned && verbosity > -1 && MLJBase.prediction_type(m) != prediction_type
@warn WARN_INCONSISTENT_PREDICTION_TYPE
user_warned = true
end
end

next === nothing && return nothing, state

m, s = next
check(m)

models = Any[m, ] # types not known until run-time

next = iterate(range, s)
Expand All @@ -33,12 +50,13 @@ function MLJTuning.models(tuning::Explicit,
while i < n_remaining
next === nothing && break
m, s = next
check(m)
push!(models, m)
i += 1
next = iterate(range, s)
end

new_state = ExplicitState(range, next)
new_state = ExplicitState(range, next, prediction_type, user_warned)

return models, new_state

Expand Down
4 changes: 3 additions & 1 deletion src/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,11 @@ function event!(metamodel,
state)
model = _first(metamodel)
metadata = _last(metamodel)
force = typeof(resampling_machine.model.model) !=
typeof(model)
resampling_machine.model.model = model
verb = (verbosity >= 2 ? verbosity - 3 : verbosity - 1)
fit!(resampling_machine, verbosity=verb)
fit!(resampling_machine; verbosity=verb, force)
E = evaluate(resampling_machine)
entry0 = (model = model,
measure = E.measure,
Expand Down
39 changes: 39 additions & 0 deletions test/strategies/explicit.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
good = KNNClassifier(K=2)
bad = KNNClassifier(K=10)
ugly = ConstantClassifier()
evil = DeterministicConstantClassifier()

r = [good, bad, ugly]

Expand Down Expand Up @@ -44,4 +45,42 @@ X, y = make_blobs(rng=rng)
@test_throws ArgumentError TunedModel(; models=[dcc, dcc])
end

r = [good, bad, evil, ugly]

@testset "inconsistent prediction types" begin
# case where different predictions types is actually okay (but still
# a warning is issued):
tmodel = TunedModel(
models=r,
resampling = Holdout(),
measure=accuracy,
)
@test_logs(
(:warn, MLJTuning.WARN_INCONSISTENT_PREDICTION_TYPE),
MLJBase.fit(tmodel, 0, X, y),
);

# verbosity = -1 suppresses the warning:
@test_logs(
MLJBase.fit(tmodel, -1, X, y),
);

# case where there really is a problem with different prediction types:
tmodel = TunedModel(
models=r,
resampling = Holdout(),
measure=log_loss,
)
@test_logs(
(:warn, MLJTuning.WARN_INCONSISTENT_PREDICTION_TYPE),
(:error,),
(:info,),
(:info,),
@test_throws(
ArgumentError, # indicates the problem is with incompatible measure
MLJBase.fit(tmodel, 0, X, y),
)
)
end

true
2 changes: 1 addition & 1 deletion test/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ results = [(evaluate(model, X, y,
tm = TunedModel(
models=r,
resampling=CV(nfolds=2),
measures=cross_entropy
measures=cross_entropy,
)
@test_logs((:error, r"Problem"),
(:info, r""),
Expand Down

0 comments on commit da15fb3

Please sign in to comment.