Skip to content

Commit

Permalink
Merge pull request #213 from JuliaAI/featimp
Browse files Browse the repository at this point in the history
add feature importances support for tuned models
  • Loading branch information
OkonSamuel authored Mar 24, 2024
2 parents 9450378 + b9d8c01 commit ac447ec
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 128 deletions.
16 changes: 16 additions & 0 deletions src/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,22 @@ function MLJBase.training_losses(tuned_model::EitherTunedModel, _report)
return ret
end

## Support for Feature Importances
# Type-level trait: a tuned-model wrapper type answers with whatever the trait
# returns for its second type parameter `M` (presumably the type of the model
# being tuned -- confirm against the `EitherTunedModel` definition).
MLJBase.reports_feature_importances(::Type{<:EitherTunedModel{<:Any,M}}) where {M} =
    MLJBase.reports_feature_importances(M)

# Instance-level counterpart of the trait: forward the query to the wrapped
# model stored in `model.model`.
# This is needed in some cases (e.g., tuning a `Pipeline`).
MLJBase.reports_feature_importances(model::EitherTunedModel) =
    MLJBase.reports_feature_importances(model.model)

# `fitresult` here is a machine created using the best model obtained from the
# tuning process, so the request is simply forwarded to that machine. The
# forwarded call returns `nothing` when the model being tuned doesn't support
# feature importances. Neither the tuned model nor `report` is consulted.
MLJBase.feature_importances(::EitherTunedModel, fitresult, report) =
    MLJBase.feature_importances(fitresult)

## METADATA

Expand Down
18 changes: 16 additions & 2 deletions test/models/DecisionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,11 @@ function MLJBase.fit(model::DecisionTreeClassifier, verbosity::Int, X, y)
#> empty values):

cache = nothing
report = (classes_seen=classes_seen,
print_tree=TreePrinter(tree))
report = (
classes_seen=classes_seen,
print_tree=TreePrinter(tree),
features=collect(Tables.columnnames(Tables.columns(X)))
)

return fitresult, cache, report
end
Expand Down Expand Up @@ -134,6 +137,17 @@ function MLJBase.predict(model::DecisionTreeClassifier
for i in 1:size(y_probabilities, 1)]
end

# Trait declaration: `DecisionTreeClassifier` (and any subtype) reports feature
# importances.
function MLJBase.reports_feature_importances(::Type{<:DecisionTreeClassifier})
    return true
end

# Return a vector of `feature => importance` pairs, ordered from the largest
# importance to the smallest.
#
# - `report.features` supplies the feature names.
# - `first(fitresult)` (presumably the fitted tree -- confirm against `fit`) is
#   handed to `DecisionTree.impurity_importance` with `normalize=true`.
function MMI.feature_importances(m::DecisionTreeClassifier, fitresult, report)
    names = report.features
    scores = DecisionTree.impurity_importance(first(fitresult); normalize=true)
    ranked = names .=> scores
    # Sorting ascending on the negated importance gives descending importance.
    sort!(ranked; by=pair -> -last(pair))

    return ranked
end

## REGRESSOR

Expand Down
113 changes: 0 additions & 113 deletions test/schizo.md

This file was deleted.

61 changes: 48 additions & 13 deletions test/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,24 @@ using Random
Random.seed!(1234*myid())
using .TestUtilities

# Shared regression fixture for the tests below: `N` observations of three
# random features, and a target that is a fixed linear combination of them
# plus random noise. The `rand` calls follow the `Random.seed!` call above, so
# their order determines the generated values.
N = 30
x1 = rand(N);
x2 = rand(N);
x3 = rand(N);
X = (; x1, x2, x3);
y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.4*rand(N);

# `m(K)` builds a `KNNRegressor` with keyword `K`; `r` holds one such model for
# each `K` from 13 down to 2.
m(K) = KNNRegressor(; K)
r = [m(K) for K in 13:-1:2]

# TODO: replace the above with the line below and post an issue on
# the failure (a bug in Distributed, I reckon):
# r = (m(K) for K in 13:-1:2)
# Shared fixtures for the tests below. The `rand` calls follow the
# `Random.seed!` call above, so their order determines the generated values.
begin
# Regression data: `N` observations of three random features, and a target that
# is a fixed linear combination of them plus random noise.
N = 30
x1 = rand(N);
x2 = rand(N);
x3 = rand(N);
X = (; x1, x2, x3);
y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.4*rand(N);

# `m(K)` builds a `KNNRegressor` with keyword `K`; `r` holds one such model for
# each `K` from 13 down to 2.
m(K) = KNNRegressor(; K)
r = [m(K) for K in 13:-1:2]

# Classification fixture used by the feature-importance tests: the data returned
# by `@load_iris` (presumably features and target -- confirm against the macro),
# and twelve `DecisionTreeClassifier`s, each with a random `pruning_purity`.
Xtree, yhat = @load_iris
trees = [DecisionTreeClassifier(pruning_purity = rand()) for _ in 13:-1:2]

# TODO: replace the array definition of `r` above with the generator below and
# post an issue on the failure (a bug in Distributed, I reckon):
# r = (m(K) for K in 13:-1:2)
end

@testset "constructor" begin
@test_throws(MLJTuning.ERR_SPECIFY_RANGE,
Expand Down Expand Up @@ -105,6 +110,10 @@ end
@test _report.best_model == collect(r)[best_index]
@test _report.history[5] == MLJTuning.delete(history[5], :metadata)

# feature_importances:
# This should return nothing as `KNNRegressor` doesn't support feature_importances
@test feature_importances(tm, fitresult, _report) === nothing

# training_losses:
losses = training_losses(tm, _report)
@test all(eachindex(losses)) do i
Expand Down Expand Up @@ -146,6 +155,32 @@ end
@test results4 results
end

# Checks that a `TunedModel` wrapping a model with feature-importance support
# exposes that support, both through the trait and through
# `feature_importances` after fitting.
@testset_accelerated "Feature Importances" accel begin
# the DecisionTreeClassifier in test/models/DecisionTree.jl supports feature
# importances.
# `tm0` is never fitted; only its type is used, to exercise the type-level trait.
tm0 = TunedModel(
model = trees[1],
measure = rms,
tuning = Grid(),
resampling = CV(nfolds = 5),
range = range(
trees[1], :max_depth, values = 1:10
)
)
@test reports_feature_importances(typeof(tm0))
# `tm` is built from an explicit list of models and is the instance that gets
# fitted; `accel` is applied to resampling only.
tm = TunedModel(
models = trees,
resampling = CV(nfolds=2),
measures = cross_entropy,
acceleration = CPU1(),
acceleration_resampling = accel
)
# Instance-level trait check.
@test reports_feature_importances(tm)
fitresult, _, report = MLJBase.fit(tm, 0, Xtree, yhat)
# Take the first component of each returned entry (the feature name) and compare
# as sets, so the order of the importances is not tested here.
features = first.(feature_importances(tm, fitresult, report))
@test Set(features) == Set(keys(Xtree))

end

@testset_accelerated(
"under/over supply of models",
accel,
Expand Down

0 comments on commit ac447ec

Please sign in to comment.