Allow InSample() resampling in IteratedModel #61

Merged · 8 commits · Jun 3, 2024

Changes from all commits
2 changes: 1 addition & 1 deletion Project.toml
@@ -11,7 +11,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[compat]
IterationControl = "0.5"
MLJBase = "1"
MLJBase = "1.3"
julia = "1.6"

[extras]
6 changes: 2 additions & 4 deletions src/MLJIteration.jl
@@ -17,15 +17,15 @@ const CONTROLS = vcat(IterationControl.CONTROLS,
:WithModelDo,
:CycleLearningRate,
:Save])

const CONTROLS_LIST = join(map(c->"$c()", CONTROLS), ", ", " and ")
const TRAINING_CONTROLS = [:Step, ]

# export all control types:
for control in CONTROLS
eval(:(export $control))
end

const CONTROLS_DEFAULT = [Step(1),
const DEFAULT_CONTROLS = [Step(1),
Patience(5),
GL(),
TimeLimit(0.03), # about 2 mins
@@ -42,6 +42,4 @@ include("traits.jl")
include("ic_model.jl")
include("core.jl")



end # module
251 changes: 141 additions & 110 deletions src/constructors.jl
@@ -1,5 +1,5 @@
const IterationResamplingTypes =
Union{Holdout,Nothing,MLJBase.TrainTestPairs}
Union{Holdout,InSample,Nothing,MLJBase.TrainTestPairs}


## TYPES AND CONSTRUCTOR
@@ -72,96 +72,119 @@ err_bad_iteration_parameter(p) =
ArgumentError("Model to be iterated does not have :($p) as an iteration parameter. ")

"""
IteratedModel(model=nothing,
controls=$CONTROLS_DEFAULT,
retrain=false,
resampling=Holdout(),
measure=nothing,
weights=nothing,
class_weights=nothing,
operation=predict,
verbosity=1,
check_measure=true,
iteration_parameter=nothing,
cache=true)

Wrap the specified `model <: Supervised` in the specified iteration
`controls`. Training a machine bound to the wrapper iterates a
corresonding machine bound to `model`. Here `model` should support
iteration.

To list all controls, do `MLJIteration.CONTROLS`. Controls are
summarized at
[https://alan-turing-institute.github.io/MLJ.jl/dev/getting_started/](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/)
but query individual doc-strings for details and advanced options. For
creating your own controls, refer to the documentation just cited.

To make out-of-sample losses available to the controls, the machine
bound to `model` is only trained on part of the data, as iteration
proceeds. See details on training below. Specify `retrain=true`
to ensure the model is retrained on *all* available data, using the
same number of iterations, once controlled iteration has stopped.

Specify `resampling=nothing` if all data is to be used for controlled
iteration, with each out-of-sample loss replaced by the most recent
training loss, assuming this is made available by the model
(`supports_training_losses(model) == true`). Otherwise, `resampling`
must have type `Holdout` (eg, `Holdout(fraction_train=0.8, rng=123)`).

Assuming `retrain=true` or `resampling=nothing`,
`iterated_model` behaves exactly like the original `model` but with
the iteration parameter automatically selected. If
`retrain=false` (default) and `resampling` is not `nothing`, then
`iterated_model` behaves like the original model trained on a subset
of the provided data.

Controlled iteration can be continued with new `fit!` calls (warm
restart) by mutating a control, or by mutating the iteration parameter
of `model`, which is otherwise ignored.


### Training

Given an instance `iterated_model` of `IteratedModel`, calling
`fit!(mach)` on a machine `mach = machine(iterated_model, data...)`
performs the following actions:

- Assuming `resampling !== nothing`, the `data` is split into *train* and
*test* sets, according to the specified `resampling` strategy, which
must have type `Holdout`.

- A clone of the wrapped model, `iterated_model.model`, is bound to
the train data in an internal machine, `train_mach`. If `resampling
=== nothing`, all data is used instead. This machine is the object
to which controls are applied. For example, `Callback(fitted_params
|> print)` will print the value of `fitted_params(train_mach)`.
IteratedModel(model;
controls=MLJIteration.DEFAULT_CONTROLS,
resampling=Holdout(),
measure=nothing,
retrain=false,
advanced_options...,
)

Wrap the specified supervised `model` in the specified iteration `controls`. Here `model`
should support iteration, which is true if `iteration_parameter(model)` is different from
`nothing`.

Available controls: $CONTROLS_LIST.

!!! important

To make out-of-sample losses available to the controls, the wrapped `model` is only
trained on part of the data, as iteration proceeds. The user may want to force
retraining on all data after controlled iteration has finished by specifying
`retrain=true`. See also "Training", and the `retrain` option, under "Extended help"
below.
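
For example, a minimal sketch, assuming EvoTrees.jl is installed (any model whose
`iteration_parameter` is not `nothing` works the same way):

```julia
using MLJ

EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees verbosity=0

imodel = IteratedModel(
    EvoTreeRegressor();
    controls=[Step(10), Patience(5), NumberLimit(50)],
    measure=rms,
)

X, y = make_regression(200, 5)
mach = machine(imodel, X, y)
fit!(mach)  # iterates the wrapped model until a control triggers a stop
```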

# Extended help

# Options

- `controls=$DEFAULT_CONTROLS`: Controls are summarized at
[https://JuliaAI.github.io/MLJ.jl/dev/controlling_iterative_models/](https://JuliaAI.github.io/MLJ.jl/dev/controlling_iterative_models/)
but query individual doc-strings for details and advanced options. For creating your own
controls, refer to the documentation just cited.

- `resampling=Holdout(fraction_train=0.7)`: The default resampling holds back 30% of data
for computing an out-of-sample estimate of performance (the "loss") for loss-based
controls such as `WithLossDo`. Specify `resampling=nothing` if all data is to be used
for controlled iteration, with each out-of-sample loss replaced by the most recent
training loss, assuming this is made available by the model
(`supports_training_losses(model) == true`). If the model does not report a training
loss, you can use `resampling=InSample()` instead. Otherwise, `resampling` must have
type `Holdout` or be a vector with one element of the form `(train_indices,
test_indices)`. A sketch of these options follows this list.

- `measure=nothing`: StatisticalMeasures.jl compatible measure for estimating model
performance (the "loss", but the orientation is immaterial - i.e., this could be a
score). Inferred by default. Ignored if `resampling=nothing`.

- `retrain=false`: If `retrain=true` or `resampling=nothing`, `iterated_model` behaves
exactly like the original `model` but with the iteration parameter automatically
selected ("learned"). That is, the model is retrained on *all* available data, using the
same number of iterations, once controlled iteration has stopped. This is typically
desired when wrapping the iterated model further, or when inserting it in a pipeline or
other composite model. If `retrain=false` (the default) and `resampling isa Holdout`, then
`iterated_model` behaves like the original model trained on a subset of the provided
data.

- `weights=nothing`: per-observation weights to be passed to `measure` where supported; if
unspecified, these are understood to be uniform.

- `class_weights=nothing`: class-weights to be passed to `measure` where supported; if
unspecified, these are understood to be uniform.

- `operation=nothing`: Operation, such as `predict` or `predict_mode`, for computing
target values, or proxy target values, for consumption by `measure`; automatically
inferred by default.

- `check_measure=true`: Specify `false` to override checks on `measure` for compatibility
with the training data.

- `iteration_parameter=nothing`: A symbol, such as `:epochs`, naming the iteration
parameter of `model`; inferred by default. Note that the actual value of the iteration
parameter in the supplied `model` is ignored; only the value of an internal clone is
mutated while training the wrapped model.

- `cache=true`: Whether or not model-specific representations of data are cached in
between iteration parameter increments; specify `cache=false` to prioritize memory over
speed.
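
As referenced in the `resampling` option above, a sketch of the accepted `resampling`
values (`model` is a hypothetical iterative model):

```julia
IteratedModel(model; resampling=Holdout(fraction_train=0.8, rng=123)) # out-of-sample losses
IteratedModel(model; resampling=InSample())       # losses computed on the training data
IteratedModel(model; resampling=nothing)          # model-reported training losses
IteratedModel(model; resampling=[(1:80, 81:100)]) # one explicit (train, test) pair
```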


# Training

Training an instance `iterated_model` of `IteratedModel` on some `data` (by binding to a
machine and calling `fit!`, for example) performs the following actions:

- Assuming `resampling !== nothing`, the `data` is split into *train* and *test* sets,
according to the specified `resampling` strategy.

- A clone of the wrapped model, `model`, is bound to the train data in an internal machine,
`train_mach`. If `resampling === nothing`, all data is used instead. This machine is the
object to which controls are applied. For example, `Callback(fitted_params |> print)`
will print the value of `fitted_params(train_mach)`.

- The iteration parameter of the clone is set to `0`.

- The specified `controls` are repeatedly applied to `train_mach` in
sequence, until one of the controls triggers a stop. Loss-based
controls (eg, `Patience()`, `GL()`, `Threshold(0.001)`) use an
out-of-sample loss, obtained by applying `measure` to predictions
and the test target values. (Specifically, these predictions are
those returned by `operation(train_mach)`.) If `resampling ===
nothing` then the most recent training loss is used instead. Some
controls require *both* out-of-sample and training losses (eg,
`PQ()`).
- The specified `controls` are repeatedly applied to `train_mach` in sequence, until one
of the controls triggers a stop. Loss-based controls (eg, `Patience()`, `GL()`,
`Threshold(0.001)`) use an out-of-sample loss, obtained by applying `measure` to
predictions and the test target values. (Specifically, these predictions are those
returned by `operation(train_mach)`.) If `resampling === nothing` then the most recent
training loss is used instead. Some controls require *both* out-of-sample and training
losses (eg, `PQ()`).

- Once a stop has been triggered, a clone of `model` is bound to all
`data` in a machine called `mach_production` below, unless
`retrain == false` or `resampling === nothing`, in which case
`mach_production` coincides with `train_mach`.
- Once a stop has been triggered, a clone of `model` is bound to all `data` in a machine
called `mach_production` below, unless `retrain == false` (the default) or
`resampling === nothing`, in which case `mach_production` coincides with `train_mach`.


### Prediction
# Prediction

Calling `predict(mach, Xnew)` returns `predict(mach_production,
Xnew)`. Similar similar statements hold for `predict_mean`,
`predict_mode`, `predict_median`.
Calling `predict(mach, Xnew)` returns `predict(mach_production, Xnew)`. Similar
statements hold for `predict_mean`, `predict_mode`, and `predict_median`.
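
For illustration, a minimal sketch (`iterated_model`, `X`, `y` and `Xnew` are assumed to
be defined, as in the earlier sketch):

```julia
mach = machine(iterated_model, X, y)
fit!(mach)                  # controlled iteration, as described under "Training"
yhat = predict(mach, Xnew)  # delegates to `mach_production`
report(mach).n_iterations   # number of iterations applied
```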


### Controls
# Controls that mutate parameters

A control is permitted to mutate the fields (hyper-parameters) of
`train_mach.model` (the clone of `model`). For example, to mutate a
@@ -174,11 +197,25 @@ in that parameter, this will trigger retraining of `train_mach` from
scratch, with a different training outcome, which is not recommended.
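
For instance, a hedged sketch using the exported `WithModelDo` control, assuming the
wrapped model has a learning-rate field `eta` (a hypothetical name):

```julia
# halve the learning rate each time this control is applied:
decay = WithModelDo(m -> m.eta = m.eta/2)
imodel = IteratedModel(model; controls=[Step(5), decay, NumberLimit(20)])
```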


### Warm restarts
# Warm restarts

If `iterated_model` is mutated and `fit!(mach)` is called again, then
a warm restart is attempted if the only parameters to change are
`model` or `controls` or both.
In the following example, the second `fit!` call does not restart training of the internal
`train_mach` from scratch, assuming `model` supports warm restarts:

```julia
iterated_model = IteratedModel(
model,
controls = [Step(1), NumberLimit(100)],
)
mach = machine(iterated_model, X, y)
fit!(mach) # train for 100 iterations
iterated_model.controls = [Step(1), NumberLimit(50)]
fit!(mach) # train for an *extra* 50 iterations
```

More generally, if `iterated_model` is mutated and `fit!(mach)` is called again, then a
warm restart is attempted if the only parameters to change are `model` or `controls` or
both.

Specifically, `train_mach.model` is mutated to match the current value
of `iterated_model.model` and the iteration parameter of the latter is
@@ -188,14 +225,14 @@ repeated application of the (updated) controls begin anew.
"""
function IteratedModel(args...;
model=nothing,
control=CONTROLS_DEFAULT,
control=DEFAULT_CONTROLS,
controls=control,
resampling=MLJBase.Holdout(),
measures=nothing,
measure=measures,
weights=nothing,
class_weights=nothing,
operation=predict,
operation=nothing,
retrain=false,
check_measure=true,
iteration_parameter=nothing,
@@ -211,30 +248,24 @@ function IteratedModel(args...;
atom = model
end

options = (
atom,
controls,
resampling,
measure,
weights,
class_weights,
operation,
retrain,
check_measure,
iteration_parameter,
cache,
)

if atom isa Deterministic
iterated_model = DeterministicIteratedModel(atom,
controls,
resampling,
measure,
weights,
class_weights,
operation,
retrain,
check_measure,
iteration_parameter,
cache)
iterated_model = DeterministicIteratedModel(options...)
elseif atom isa Probabilistic
iterated_model = ProbabilisticIteratedModel(atom,
controls,
resampling,
measure,
weights,
class_weights,
operation,
retrain,
check_measure,
iteration_parameter,
cache)
iterated_model = ProbabilisticIteratedModel(options...)
else
throw(ERR_NOT_SUPERVISED)
end
24 changes: 23 additions & 1 deletion test/core.jl
@@ -6,6 +6,7 @@ using IterationControl
using MLJBase
using MLJModelInterface
using StatisticalMeasures
using StableRNGs
using ..DummyModel

X, y = make_dummy(N=20)
@@ -26,6 +27,7 @@ model = DummyIterativeModel(n=0)
end
IterationControl.loss(mach::Machine{<:DummyIterativeModel}) =
last(training_losses(mach))

IterationControl.train!(mach, controls..., verbosity=0)
losses1 = report(mach).training_losses
yhat1 = predict(mach, X)
@@ -104,6 +106,26 @@ model = DummyIterativeModel(n=0)
@test report(mach).n_iterations == 5
end

@testset "resampling = InSample()" begin
model = DummyIterativeModel(n=0, rng=StableRNG(123))
controls=[Step(2), NumberLimit(10)]

# using `resampling=nothing`:
imodel = IteratedModel(model=model, controls=controls, resampling=nothing)
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)
y1 = predict(mach, rows=1:10)

# using `resampling=InSample()`:
model = DummyIterativeModel(n=0, rng=StableRNG(123))
imodel = IteratedModel(model=model, controls=controls, resampling=InSample())
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)
y2 = predict(mach, rows=1:10)

@test y1 == y2
end

@testset "integration: resampling=Holdout()" begin

controls=[Step(2), Patience(4), TimeLimit(0.001)]
@@ -269,7 +291,7 @@ function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
end

@testset "save and restore" begin
#https://github.com/alan-turing-institute/MLJ.jl/issues/1099
#https://github.com/JuliaAI/MLJ.jl/issues/1099
X, y = (; x = rand(10)), fill(42.0, 3)
controls = [Step(1), NumberLimit(2)]
imodel = IteratedModel(