Skip to content

Commit

Permalink
Merge pull request #36 from OutlierDetectionJL/compathelper/new_versi…
Browse files Browse the repository at this point in the history
…on/2022-11-08-00-53-17-725-02140817060

CompatHelper: bump compat for MLJBase to 0.21, (keep existing compat)
  • Loading branch information
davnn authored Sep 29, 2023
2 parents 5ed27b9 + d1fff62 commit 77528ee
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 30 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OutlierDetection"
uuid = "262411bb-c475-4342-ba9e-03b8c0183ca6"
authors = ["David Muhr <[email protected]> and contributors"]
version = "0.3.3"
version = "0.3.4"

[deps]
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Expand All @@ -10,7 +10,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
MLJBase = "~0.20.12"
MLJBase = "0.21"
OutlierDetectionInterface = "~0.1.8"
SpecialFunctions = "1, 2"
julia = "^1.6"
Expand Down
4 changes: 1 addition & 3 deletions src/mlj_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,7 @@ function augmented_transform(mach::MLJ.Machine{<:OD.Detector}; rows=:)
end

function get_scores_from_composite_report(mach)
# new #banana API
# fit_report = MLJ.report_given_method(mach)[:fit]
fit_report = mach.report
fit_report = MLJ.report_given_method(mach)[:fit]
if haskey(fit_report, :additions) && haskey(fit_report.additions, :scores)
return fit_report.additions.scores
else
Expand Down
2 changes: 1 addition & 1 deletion src/mlj_wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Determine the input scitype of an array of detectors
to_input_scitype(detectors) = MLJ.glb(MLJ.input_scitype.(detectors)...)
to_input_scitype(detectors) = MLJ._glb(MLJ.input_scitype.(detectors)...)

# Determine the supported composite detector type
ex_to_eltype(type_symbol) = type_symbol == :Unsupervised ? UnsupervisedDetector :
Expand Down
28 changes: 4 additions & 24 deletions test/tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ raw_unsupervised_detector = MMIUnsupervisedDetector()
supervised_detector = ODSupervisedDetector()
raw_supervised_detector = MMISupervisedDetector()

basic_unsupervised_detectors = [unsupervised_detector, raw_unsupervised_detector]
basic_supervised_detectors = [supervised_detector, raw_supervised_detector]
unsupervised_detectors = [unsupervised_detector, raw_unsupervised_detector]
supervised_detectors = [supervised_detector, raw_supervised_detector]

# raw machines
unsupervised_machines = [fit!(machine(detector, X)) for detector in basic_unsupervised_detectors]
supervised_machines = [fit!(machine(detector, X, y)) for detector in basic_supervised_detectors]
unsupervised_machines = [fit!(machine(detector, X)) for detector in unsupervised_detectors]
supervised_machines = [fit!(machine(detector, X, y)) for detector in supervised_detectors]

# surrogate machines
unsupervised_surrogate = score_surrogate_machine(unsupervised_detector, Xs, ys) |> fit!
Expand Down Expand Up @@ -103,22 +103,8 @@ deterministic_surrogate_machines = [
supervised_deterministic_surrogate,
raw_supervised_deterministic_surrogate]

# surrogate detectors
@from_network unsupervised_surrogate mutable struct CustomUnsupervisedDetector end
@from_network raw_unsupervised_surrogate mutable struct RawCustomUnsupervisedDetector end
@from_network supervised_surrogate mutable struct CustomSupervisedDetector end
@from_network raw_supervised_surrogate mutable struct RawCustomSupervisedDetector end

surrogate_unsupervised = CustomUnsupervisedDetector()
raw_surrogate_unsupervised = RawCustomUnsupervisedDetector()
surrogate_supervised = CustomSupervisedDetector()
raw_surrogate_supervised = RawCustomSupervisedDetector()

# composite machines
unsupervised_detectors = [basic_unsupervised_detectors..., surrogate_unsupervised, raw_surrogate_unsupervised]
unsupervised_detectors = [unsupervised_detectors..., map(CompositeDetector, unsupervised_detectors)...]

supervised_detectors = [basic_supervised_detectors..., surrogate_supervised, raw_surrogate_supervised]
supervised_detectors = [supervised_detectors..., map(CompositeDetector, supervised_detectors)...]

# create composite detectors from raw detectors and already wrapped detectors
Expand Down Expand Up @@ -313,12 +299,6 @@ end
end

@testset "erroneous wrapper calls" begin
# wrappers do not work with models other than detectors
static_model = MLJBase.WrappedFunction(identity)
@test_throws MethodError CompositeDetector(static_model)
@test_throws MethodError ProbabilisticDetector(static_model)
@test_throws MethodError DeterministicDetector(static_model)

# wrappers do not work with multiple unnamed detectors
@test_throws ArgumentError CompositeDetector(unsupervised_detector, supervised_detector)
@test_throws ArgumentError ProbabilisticDetector(unsupervised_detector, supervised_detector)
Expand Down

0 comments on commit 77528ee

Please sign in to comment.