Skip to content

Commit

Permalink
add rng argument to splitobs, undersample, oversample (#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Jan 25, 2025
1 parent bea6ec4 commit db92264
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
/Manifest.toml
docs/build
.vscode
/test.jl
/test.jl
41 changes: 27 additions & 14 deletions src/resample.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
oversample(data, classes; fraction=1, shuffle=true)
oversample(data::Tuple; fraction=1, shuffle=true)
oversample([rng], data, classes; fraction=1, shuffle=true)
oversample([rng], data::Tuple; fraction=1, shuffle=true)
Generate a re-balanced version of `data` by repeatedly sampling
existing observations in such a way that every class will have at
Expand All @@ -21,6 +21,9 @@ resulting data will be shuffled after its creation; if it is not
shuffled then all the repeated samples will be together at the
end, sorted by class. Defaults to `true`.
The random number generator `rng` can be optionally passed as the
first argument.
The output will contain both the resampled data and classes.
```julia
Expand All @@ -44,7 +47,7 @@ X_bal, Y_bal = oversample(X, Y)
For this function to work, the type of `data` must implement
[`numobs`](@ref) and [`getobs`](@ref).
Note that if `data` is a tuple and `classes` is not given,
If `data` is a tuple and `classes` is not given,
then it will be assumed that the last element of the tuple contains the classes.
```julia
Expand Down Expand Up @@ -77,7 +80,10 @@ julia> getobs(oversample(data, data.Y))
See [`ObsView`](@ref) for more information on data subsets.
See also [`undersample`](@ref).
"""
function oversample(data, classes; fraction=1, shuffle::Bool=true)
oversample(data, classes; kws...) = oversample(Random.default_rng(), data, classes; kws...)
oversample(data::Tuple; kws...) = oversample(Random.default_rng(), data; kws...)

function oversample(rng::AbstractRNG, data, classes; fraction=1, shuffle::Bool=true)
lm = group_indices(classes)

maxcount = maximum(length, values(lm))
Expand All @@ -94,24 +100,25 @@ function oversample(data, classes; fraction=1, shuffle::Bool=true)
end
if num_extra_needed > 0
if shuffle
append!(inds, sample(inds_for_lbl, num_extra_needed; replace=false))
append!(inds, sample(rng, inds_for_lbl, num_extra_needed; replace=false))
else
append!(inds, inds_for_lbl[1:num_extra_needed])
end
end
end

shuffle && shuffle!(inds)
shuffle && shuffle!(rng, inds)
return obsview(data, inds), obsview(classes, inds)
end

function oversample(data::Tuple; kws...)
d, c = oversample(data[1:end-1], data[end]; kws...)
function oversample(rng::AbstractRNG, data::Tuple; kws...)
d, c = oversample(rng, data[1:end-1], data[end]; kws...)
return (d..., c)
end

"""
undersample(data, classes; shuffle=true)
undersample([rng], data, classes; shuffle=true)
undersample([rng], data::Tuple; shuffle=true)
Generate a class-balanced version of `data` by subsampling its
observations in such a way that the resulting number of
Expand All @@ -124,6 +131,9 @@ resulting data will be shuffled after its creation; if it is not
shuffled then all the observations will be in their original
order. Defaults to `false`.
If `data` is a tuple and `classes` is not given,
then it will be assumed that the last element of the tuple contains the classes.
The output will contain both the resampled data and classes.
```julia
Expand Down Expand Up @@ -176,25 +186,28 @@ julia> getobs(undersample(data, data.Y))
See [`ObsView`](@ref) for more information on data subsets.
See also [`oversample`](@ref).
"""
function undersample(data, classes; shuffle::Bool=true)
undersample(data, classes; kws...) = undersample(Random.default_rng(), data, classes; kws...)
undersample(data::Tuple; kws...) = undersample(Random.default_rng(), data; kws...)

function undersample(rng::AbstractRNG, data, classes; shuffle::Bool=true)
lm = group_indices(classes)
mincount = minimum(length, values(lm))

inds = Int[]

for (lbl, inds_for_lbl) in lm
if shuffle
append!(inds, sample(inds_for_lbl, mincount; replace=false))
append!(inds, sample(rng, inds_for_lbl, mincount; replace=false))
else
append!(inds, inds_for_lbl[1:mincount])
end
end

shuffle ? shuffle!(inds) : sort!(inds)
shuffle ? shuffle!(rng, inds) : sort!(inds)
return obsview(data, inds), obsview(classes, inds)
end

function undersample(data::Tuple; kws...)
d, c = undersample(data[1:end-1], data[end]; kws...)
function undersample(rng::AbstractRNG, data::Tuple; kws...)
d, c = undersample(rng, data[1:end-1], data[end]; kws...)
return (d..., c)
end
10 changes: 7 additions & 3 deletions src/splitobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
end

"""
splitobs(data; at, shuffle=false) -> Tuple
splitobs([rng], data; at, shuffle=false) -> Tuple
Partition the `data` into two or more subsets.
When `at` is a number (between 0 and 1) this specifies the proportion in the first subset.
When `at` is a tuple, each entry specifies the proportion an a subset,
with the last having `1-sum(at)`. In all there are `length(at)+1` subsets returned.
If `shuffle=true`, randomly permute the observations before splitting.
A random number generator `rng` can be optionally passed as the first argument.
Supports any datatype implementing the [`numobs`](@ref) and
[`getobs`](@ref) interfaces -- including arrays, tuples & NamedTuples of arrays.
Expand All @@ -68,9 +70,11 @@ julia> vec(test[1]) .+ 100 == test[2]
true
```
"""
function splitobs(data; at, shuffle::Bool=false)
splitobs(data; kws...) = splitobs(Random.default_rng(), data; kws...)

function splitobs(rng::AbstractRNG, data; at, shuffle::Bool=false)
if shuffle
data = shuffleobs(data)
data = shuffleobs(rng, data)
end
n = numobs(data)
return map(idx -> obsview(data, idx), splitobs(n; at))
Expand Down
7 changes: 7 additions & 0 deletions test/splitobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,11 @@ end
@test s != X[:,1:2]
@test any(s[:,1] == x for x in eachcol(X))
@test any(s[:,2] == x for x in eachcol(X))

data = 1:100
rng = MersenneTwister(1234)
p1, _ = splitobs(rng, data, at=3, shuffle=true)
rng = MersenneTwister(1234)
p2, _ = splitobs(rng, data, at=3, shuffle=true)
@test p1 == p2
end

0 comments on commit db92264

Please sign in to comment.