Add random forest method & documentation changes #22

Merged
merged 2 commits into from
Jul 1, 2024
5 changes: 4 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Mice"
uuid = "d4678d24-b338-4f96-a2c8-a66549d61c16"
authors = ["Tom Metherell <[email protected]> and contributors"]
version = "0.3.3"
version = "0.3.4"

[deps]
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
@@ -19,15 +19,18 @@ StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
BetaML = "024491cd-cc6b-443e-8034-08ea7eb7db2b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"

[extensions]
MiceBetaMLExt = "BetaML"
MiceDataFramesExt = "DataFrames"
MiceRCallExt = "RCall"

[compat]
AxisArrays = "0.4"
BetaML = "0.8, 0.9, 0.10, 0.11, 0.12"
CategoricalArrays = "0.10"
DataFrames = "1.6"
Distributions = "0.25"
7 changes: 6 additions & 1 deletion docs/make.jl
@@ -8,7 +8,12 @@ makedocs(
pages = ["index.md",
"gettingstarted.md",
"wrangling.md",
"imputation.md",
"Imputation" => [
"mice.md",
"customising-imputation.md",
"diagnostics.md",
"binding-imputations.md"
],
"analysis.md",
"pooling.md",
"rcall.md",
10 changes: 10 additions & 0 deletions docs/src/binding-imputations.md
@@ -0,0 +1,10 @@
# Binding imputations together
If you have a number of `Mids` objects that were produced in the same way (e.g. through [multithreading](@ref Multithreading)), you can bind them together into a single `Mids` object using the function `bindImputations`. Note that the log of events might not make sense in the resulting object: it is better to inspect the logs of the individual objects before binding them together.

```@docs
bindImputations
```

```@raw html
<br> <div align="right"> Funded by Wellcome &nbsp;&nbsp;&nbsp; <img src="../wellcome-logo-white.png" style="vertical-align:middle" alt="Wellcome logo" width="50" height="50"> </div>
```
58 changes: 19 additions & 39 deletions docs/src/imputation.md → docs/src/customising-imputation.md
@@ -1,15 +1,8 @@
# Imputation (`mice`)
The main function of the package is `mice`, which takes a `Tables.jl`-compatible table as its input. It returns a multiply imputed dataset (`Mids`) object with the imputed values.
# Customising the imputation setup

```@docs
Mids
mice
```

## Customising the imputation setup
You can customise various aspects of the imputation setup by passing keyword arguments to `mice`. These are described above. You can also use some of the functions below to define objects that you can customise to alter how `mice` handles the imputation.

### Locations to impute
## Locations to impute
You can customise which data points are imputed by manipulating the `imputeWhere` argument. By default, this will specify that all missing data are to be imputed (using the function `findMissings()`).

```@docs
@@ -48,7 +41,7 @@ myImputeWhere
mice(myData, imputeWhere = myImputeWhere)
```

### Visit sequence
## Visit sequence
The visit sequence is the order in which the variables are imputed. By default, `mice` sorts the variables in order of missingness (lowest to highest) via the internal function `makeMonotoneSequence`. You can instead define your own visit sequence by creating a vector of variable names in your desired order and passing that to `mice`. For example:

```julia
@@ -91,7 +84,7 @@ Assuming that the imputations converge normally, changing the visit sequence sho

You can leave variables out of the `visitSequence` to cause `mice()` to not impute them.

### Predictor matrix
## Predictor matrix
The predictor matrix defines which variables in the imputation model are used to predict which others. By default, every variable predicts every other variable, but there is a wide range of cases in which this is not desirable. For example, if your dataset includes an ID column, it carries no information useful for imputation and should be ignored.

To create a default predictor matrix that you can edit, you can use the function `makePredictorMatrix`.
@@ -152,16 +145,30 @@ Random.seed!(1234); # Set random seed for reproducibility
mice(myData, predictorMatrix = myPredictorMatrix)
```

### Methods
## Methods
The imputation methods are the functions that are used to impute each variable. By default, `mice` uses predictive mean matching (`"pmm"`) for all variables. Currently `Mice.jl` supports the following methods:

| Method | Description | Variable type |
| ------ | ----------- | ------------- |
| `pmm` | Predictive mean matching | Any |
| `rf` | Random forest | Any (but see [below](#rf-warning)) |
| `sample` | Random sample from observed values | Any |
| `mean` | Mean of observed values | Numeric (float) |
| `norm` | Bayesian linear regression | Numeric (float) |

```@raw html
<a name="rf-warning">
</a>
```

!!! warning

If you use `rf` on a variable with integer values, the imputed values will be rounded to the nearest integer in the output. If you want to prevent this behaviour, you have two options:

* Convert the variable to a float before imputation, so it is treated as continuous or

* Convert the variable to a categorical/string array so it is treated as discrete.
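
The two workarounds in the warning above can be sketched with Base Julia alone. This is a minimal illustration rather than part of the package: the column `stage` and its values are made up, and a `CategoricalArray` (from CategoricalArrays.jl) could be used in place of the string vector shown here.

```julia
# Hypothetical integer column with missing values (illustrative data only)
stage = Union{Missing, Int}[1, 2, missing, 4, 3, missing]

# Option 1: convert to float so the variable is treated as continuous
stageFloat = convert(Vector{Union{Missing, Float64}}, stage)

# Option 2: convert to strings so the variable is treated as discrete
stageString = Union{Missing, String}[ismissing(v) ? missing : string(v) for v in stage]
```

Either converted vector would then replace the original column in the table before it is passed to `mice`.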

The `mean` and `sample` methods should not generally be used.

To create a default methods vector, use the function `makeMethods`.
@@ -218,31 +225,4 @@ Random.seed!(1234); # Set random seed for reproducibility

# Not run
mice(myData, methods = myMethods)
```

## Diagnostics
After performing multiple imputation, you should inspect the trace plots of the imputed variables to verify convergence. `Mice.jl` includes a plotting function to do this.

```@docs
plot
```

You do need to load the package `Plots.jl` to see the plots:

```julia
using Plots

# Not run
plot(myMids, 7)
```

## Binding imputations together
If you have a number of `Mids` objects that were produced in the same way (e.g. through [multithreading](@ref Multithreading)), you can bind them together into a single `Mids` object using the function `bindImputations`. Note that the log of events might not make sense in the resulting object: it is better to inspect the logs of the individual objects before binding them together.

```@docs
bindImputations
```

```@raw html
<br> <div align="right"> Funded by Wellcome &nbsp;&nbsp;&nbsp; <img src="../wellcome-logo-white.png" style="vertical-align:middle" alt="Wellcome logo" width="50" height="50"> </div>
```
16 changes: 16 additions & 0 deletions docs/src/diagnostics.md
@@ -0,0 +1,16 @@
# Diagnostics

After performing multiple imputation, you should inspect the trace plots of the imputed variables to verify convergence. `Mice.jl` includes a plotting function to do this.

```@docs
plot
```

You do need to load the package `Plots.jl` to see the plots:

```julia
using Plots

# Not run
plot(myMids, 7)
```
8 changes: 8 additions & 0 deletions docs/src/mice.md
@@ -0,0 +1,8 @@
# Imputation (`mice`)

The main function of the package is `mice`, which takes a `Tables.jl`-compatible table as its input. It returns a multiply imputed dataset (`Mids`) object with the imputed values.

```@docs
Mids
mice
```
49 changes: 49 additions & 0 deletions ext/MiceBetaMLExt.jl
@@ -0,0 +1,49 @@
module MiceBetaMLExt
using BetaML: fit!, RandomForestImputer, NONE
using CategoricalArrays: CategoricalArray, CategoricalPool, CategoricalValue, levels
using Mice: makeMethods, mice
import Mice: rfImpute!
using PrecompileTools: @compile_workload
using Random: rand, randperm

function rfImpute!(
y::AbstractArray,
X::Matrix{Float64},
whereY::Vector{Bool};
kwargs...
)

yDecat = y isa CategoricalArray || eltype(y) <: CategoricalValue ? Vector{String}(string.(y)) : y

yX = Matrix{Union{Missing, eltype(yDecat), Float64}}(hcat(yDecat, X))

yX[whereY, 1] .= missing

ŷX = fit!(RandomForestImputer(n_trees = 10, verbosity = NONE; kwargs...), yX)

return y == yDecat ? (eltype(y) <: Integer ? round.(ŷX[whereY, 1], digits = 0) : ŷX[whereY, 1]) : parse.(eltype(levels(y)), ŷX[whereY, 1])
end

@compile_workload begin
catPool = CategoricalPool(["a", "b", "c"])
ct = (
a = Vector{Union{Missing, Int}}(randperm(20)),
b = Vector{Union{Missing, Float64}}(randperm(20)),
c = Vector{Union{Missing, String}}(rand(["a", "b", "c"], 20)),
d = Vector{Union{Missing, Bool}}(rand(Bool, 20)),
e = CategoricalArray{Union{Missing, Int}}(rand([1, 2, 3], 20)),
f = CategoricalArray{Union{Missing, String}}(rand(["a", "b", "c"], 20)),
g = Vector{Union{Missing, CategoricalValue}}(rand([CategoricalValue(catPool, 1), CategoricalValue(catPool, 2), CategoricalValue(catPool, 3)], 20))
)

for col in ct
col[rand(1:20, 1)] .= missing
end

rfMethods = makeMethods(ct)
rfMethods["b"] = "rf"
imputedDataRf = mice(ct, m = 1, iter = 1, methods = rfMethods, progressReports = false)
end

export rfImpute!
end
1 change: 1 addition & 0 deletions src/Mice.jl
@@ -67,6 +67,7 @@ module Mice
include("pacify.jl")
include("pmmImpute.jl")
include("quantify.jl")
include("rfImpute.jl")
include("sampleImpute.jl")

"""
4 changes: 4 additions & 0 deletions src/rfImpute.jl
@@ -0,0 +1,4 @@
function rfImpute!
# This function depends on the RandomForestImputer from BetaML.jl.
# Without BetaML.jl, this function will not work.
end
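
The empty `function rfImpute! ... end` above declares a generic function with no methods, which lets the extension `MiceBetaMLExt` add a method once BetaML is loaded. A minimal sketch of that pattern follows, using plain nested modules in place of a real package extension. The names `Host`, `Plugin`, and `impute!` are hypothetical and chosen only for illustration.

```julia
module Host
    # Declare a generic function with zero methods (a stub)
    function impute! end
end

# Before any method is added, calling the stub throws a MethodError
stubFailed = try
    Host.impute!([1.0, 2.0])
    false
catch err
    err isa MethodError
end

module Plugin
    import ..Host: impute!

    # Add a method to the function owned by Host
    function impute!(y::Vector{Float64})
        y .= 0.0
        return y
    end
end

# The method defined in Plugin is now reachable through Host
result = Host.impute!([1.0, 2.0])
```

In this PR the role of `Plugin` is played by `ext/MiceBetaMLExt.jl`, which is declared under `[extensions]` in `Project.toml` with BetaML as a weak dependency.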
6 changes: 4 additions & 2 deletions src/sampler.jl
@@ -60,7 +60,7 @@ function sampler!(

# For variables using a conditional method
# Also check that there are actually data to be imputed in the column
elseif methods[yVar] ∈ ["norm", "pmm"] && any(whereY)
elseif methods[yVar] ∈ ["norm", "pmm", "rf"] && any(whereY)

# For each imputation
for j in 1:m
@@ -89,6 +89,8 @@ function sampler!(
workingData[yVar][j][whereY] = pmmImpute!(workingData[yVar][j][.!whereY], X, whereY, whereCount, 5, 1e-5, yVar, iterCounter, j, loggedEvents)
elseif methods[yVar] == "norm"
workingData[yVar][j][whereY] = normImpute!(workingData[yVar][j][.!whereY], X, whereY, whereCount, 1e-5, yVar, iterCounter, j, loggedEvents)
elseif methods[yVar] == "rf"
workingData[yVar][j][whereY] = rfImpute!(workingData[yVar][j], X, whereY)
end
else
# Log an event explaining why the imputation was skipped
@@ -115,7 +117,7 @@ function sampler!(
if methods[yVar] == ""
push!(loggedEvents, "Iteration $iterCounter, variable $yVar: imputation skipped - no method specified.")
# Invalid method specified
elseif !(methods[yVar] ∈ ["mean", "norm", "pmm", "sample"])
elseif !(methods[yVar] ∈ ["mean", "norm", "pmm", "rf", "sample"])
push!(loggedEvents, "Iteration $iterCounter, variable $yVar: imputation skipped - method not supported.")
# Neither of these => there is no missing data
else
1 change: 1 addition & 0 deletions test/Project.toml
@@ -1,4 +1,5 @@
[deps]
BetaML = "024491cd-cc6b-443e-8034-08ea7eb7db2b"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
30 changes: 29 additions & 1 deletion test/runtests.jl
@@ -1,4 +1,4 @@
using CategoricalArrays, CSV, DataFrames, GLM, Mice, Tables, Test, TypedTables
using BetaML, CategoricalArrays, CSV, DataFrames, GLM, Mice, Tables, Test, TypedTables

@testset "Mice (PMM, DF)" begin
data = CSV.read("data/cirrhosis.csv", DataFrame, missingstring = "NA")
@@ -146,5 +146,33 @@ end

results = pool(analyses)

@test length(results.coefs) == 7
end

@testset "Mice (RF, DF)" begin
data = CSV.read("data/cirrhosis.csv", DataFrame, missingstring = "NA")

data.Stage = categorical(data.Stage)

theMethods = makeMethods(data)
theMethods .= "rf"

predictorMatrix = makePredictorMatrix(data)
predictorMatrix[:, ["ID", "N_Days"]] .= false

imputedData = mice(data, iter = 1, methods = theMethods, predictorMatrix = predictorMatrix, progressReports = false)

@test length(imputedData.loggedEvents) == 0

imputedDataList = listComplete(imputedData)

@test sum(sum.(ismissing.(Matrix.(imputedDataList)))) == 0

analyses = with(imputedData, data -> lm(@formula(N_Days ~ Drug + Age + Stage + Bilirubin), data))

@test length(analyses.analyses) == 5

results = pool(analyses)

@test length(results.coefs) == 7
end