Skip to content

Commit

Permalink
fix batchseq for vectors of vectors (#189)
Browse files Browse the repository at this point in the history
* batchseq

* test more
  • Loading branch information
CarloLucibello authored Jan 25, 2025
1 parent ea116e6 commit fa23fbb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
16 changes: 8 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,10 +429,10 @@ julia> batchseq([[1, 2, 3], [4, 5]], 0)
[3, 0]
```
"""
function batchseq(xs, val = 0, n = nothing)
n = n === nothing ? maximum(x -> size(x, ndims(x)), xs) : n
function batchseq(xs, val = 0)
n = maximum(numobs, xs)
xs_ = [rpad_constant(x, n, val; dims=ndims(x)) for x in xs]
[batch([obsview(xs_[j], i) for j = 1:length(xs_)]) for i = 1:n]
return [batch([getobs(xs_[j], i) for j = 1:length(xs_)]) for i = 1:n]
end

"""
Expand Down Expand Up @@ -464,11 +464,11 @@ julia> rpad_constant([1 2; 3 4], 4; dims=1) # padding along the first dimension
0 0
julia> rpad_constant([1 2; 3 4], 4) # padding along all dimensions by default
2 Matrix{Int64}:
1 2
3 4
0 0
0 0
4 Matrix{Int64}:
1 2 0 0
3 4 0 0
0 0 0 0
0 0 0 0
```
"""
function rpad_constant(x::AbstractArray, n::Union{Integer, Tuple}, val=0; dims=:)
Expand Down
18 changes: 14 additions & 4 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,20 @@ end
@test bs[2] == [2, 5]
@test bs[3] == [3, -1]

batchseq([ones(2,4), zeros(2, 3), ones(2,2)]) ==[[1.0 0.0 1.0; 1.0 0.0 1.0]
[1.0 0.0 1.0; 1.0 0.0 1.0]
[1.0 0.0 0.0; 1.0 0.0 0.0]
[1.0 0.0 0.0; 1.0 0.0 0.0]]
bs = batchseq([[ones(3), ones(3), ones(3)], [zeros(3), zeros(3)]], [-1,-1,-1])
@test bs isa Vector{Matrix{Float64}}
@test bs[1] == [1.0 0.0; 1.0 0.0; 1.0 0.0]
@test bs[2] == [1.0 0.0; 1.0 0.0; 1.0 0.0]
@test bs[3] == [1.0 -1.0; 1.0 -1.0; 1.0 -1.0]

bs = batchseq([ones(2,4), zeros(2, 3), ones(2,2)])
@test bs isa Vector{Matrix{Float64}}
@test bs[1] == [1.0 0.0 1.0; 1.0 0.0 1.0]
@test bs[2] == [1.0 0.0 1.0; 1.0 0.0 1.0]
@test bs[3] == [1.0 0.0 0.0; 1.0 0.0 0.0]
@test bs[4] == [1.0 0.0 0.0; 1.0 0.0 0.0]


end

@testset "ones_like" begin
Expand Down

0 comments on commit fa23fbb

Please sign in to comment.