Skip to content

Commit

Permalink
implement trues_like, falses_like + vector indexing for joinobs (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Jan 25, 2025
1 parent 86ed1e5 commit 5c20c7b
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ include("utils.jl")
export batch,
batchseq,
chunk,
falses_like,
fill_like,
flatten,
group_counts,
Expand All @@ -76,6 +77,7 @@ export batch,
randn_like,
rpad_constant,
stack, # in Base since julia v1.9
trues_like,
unbatch,
unsqueeze,
unstack,
Expand Down
7 changes: 6 additions & 1 deletion src/obstransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ JoinedData(datas) = JoinedData(datas, numobs.(datas))

Base.length(data::JoinedData) = sum(data.ns)

function Base.getindex(data::JoinedData, idx)
function Base.getindex(data::JoinedData, idx::Integer)
@assert 1 <= idx <= length(data)
for (i, n) in enumerate(data.ns)
if idx <= n
return getobs(data.datas[i], idx)
Expand All @@ -178,6 +179,10 @@ function Base.getindex(data::JoinedData, idx)
end
end

function Base.getindex(data::JoinedData, idx::AbstractVector{<:Integer})
return [data[i] for i in idx]
end

"""
joinobs(datas...)
Expand Down
19 changes: 19 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,25 @@ julia> fill_like(x, 1.7, Float64)
fill_like(x::AbstractArray, val, T::Type, sz=size(x)) = fill!(similar(x, T, sz), val)
fill_like(x::AbstractArray, val, sz=size(x)) = fill_like(x, val, eltype(x), sz)

"""
trues_like(x, [dims=size(x)])
Equivalent to `fill_like(x, true, Bool, dims)`.
See also [`fill_like`] and [`falses_like`](@ref).
"""
trues_like(x::AbstractArray, sz=size(x)) = fill_like(x, true, Bool, sz)

"""
falses_like(x, [dims=size(x)])
Equivalent to `fill_like(x, false, Bool, dims)`.
See also [`fill_like`] and [`trues_like`](@ref).
"""
falses_like(x::AbstractArray, sz=size(x)) = fill_like(x, false, Bool, sz)


@non_differentiable zeros_like(::Any...)
@non_differentiable ones_like(::Any...)
@non_differentiable rand_like(::Any...)
Expand Down
5 changes: 5 additions & 0 deletions test/obstransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ end
data1, data2 = 1:10, 11:20
jdata = joinobs(data1, data2)
@test getobs(jdata, 15) == 15

data = joinobs(1:5, 6:10)
@test data[5:6] == [5, 6]
data = joinobs(ones(2, 3), zeros(2, 3))
@test data[3:4] == [[1.0, 1.0], [0.0, 0.0]]
end

@testset "shuffleobs" begin
Expand Down
11 changes: 11 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,17 @@ end
test_zygote(fill_like, rand(5), rand(), (2, 4, 2))
end

@testset "trues_like and falses_like" begin
x = rand(Float16, 2, 3)
y = trues_like(x, (2, 4, 2))
@test y isa Array{Bool,3}
@test y == trues(2, 4, 2)

y = falses_like(x, (2, 4, 2))
@test y isa Array{Bool,3}
@test y == falses(2, 4, 2)
end

@testset "rpad_constant" begin
@test rpad_constant([1, 2], 4, -1) == [1, 2, -1, -1]
@test rpad_constant([1, 2, 3], 2) == [1, 2, 3]
Expand Down

0 comments on commit 5c20c7b

Please sign in to comment.