Skip to content

Commit

Permalink
eachmatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
milankl committed Jul 4, 2024
1 parent 1379fb9 commit 575529c
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 21 deletions.
2 changes: 1 addition & 1 deletion ext/SpeedyWeatherJLArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module SpeedyWeatherJLArraysExt
using SpeedyWeather, JLArrays

# for RingGrids and LowerTriangularMatrices:
# every Array needs this method to strip away the parameters
# every Array needs this method to strip away the parameters
SpeedyWeather.RingGrids.nonparametric_type(::Type{<:JLArray}) = JLArray
SpeedyWeather.LowerTriangularMatrices.nonparametric_type(::Type{<:JLArray}) = JLArray

Expand Down
2 changes: 1 addition & 1 deletion src/LowerTriangularMatrices/LowerTriangularMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import UnicodePlots
# export plot

export LowerTriangularMatrix, LowerTriangularArray
export eachharmonic
export eachharmonic, eachmatrix

include("lower_triangular_matrix.jl")
include("plot.jl")
Expand Down
98 changes: 79 additions & 19 deletions src/LowerTriangularMatrices/lower_triangular_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ end
check_lta_input_array(data, m, n, N) =
(ndims(data) == N) & (length(data) == prod(size(data)[2:end]) * nonzeros(m, n))

matrix_size(data::AbstractArray, m::Integer, n::Integer) = (m, n, size(data)[2:end]...)

function lta_error_message(data, m, n, T, N, ArrayType)
size_tuple = (m, n, size(data[2:end])...)
return "$(size(data))-sized $(typeof(data)) cannot be used to create "*
"a $size_tuple LowerTriangularArray{$T,$N,$ArrayType}"
return "$(size2x_string(size(data)))-sized $(typeof(data)) cannot be used to create "*
"a $(size2x_string(matrix_size(data, m, n))) LowerTriangularArray{$T, $N, $ArrayType}"
end

"""2-dimensional `LowerTriangularArray` of type `T`` with its non-zero entries unravelled into a `Vector{T}`"""
Expand All @@ -47,8 +48,8 @@ and as if it were a full matrix when `as <: AbstractMatrix`` ."""
Base.size(L::LowerTriangularArray; as::T=Vector) where T = size(L, as)
Base.size(L::LowerTriangularArray, i::Integer; as::T=Vector) where T = size(L; as=as)[i]

Base.size(L::LowerTriangularArray, as::Type{<:AbstractMatrix}) = (L.m, L.n, size(L.data)[2:end]...)
Base.size(L::LowerTriangularArray, as::Type{<:AbstractVector}) = size(L.data)
Base.size(L::LowerTriangularArray, as::Type{Matrix}) = matrix_size(L.data, L.m, L.n)
Base.size(L::LowerTriangularArray, as::Type{Vector}) = size(L.data)

# sizeof the underlying data vector
Base.sizeof(L::LowerTriangularArray) = sizeof(L.data)
Expand Down Expand Up @@ -163,7 +164,9 @@ Angeletti et al, 2019, https://hal.science/hal-02047514/document)
@inline function k2ij(k::Integer, m::Integer)
kp = triangle_number(m) - k
p = Int(floor((sqrt(1 + 8*kp) - 1)/2))
(k - m*(m-1)÷2 + p*(p+1)÷2, m - p)
i = k - m*(m-1)÷2 + p*(p+1)÷2
j = m - p
return i, j
end
k2ij(I::CartesianIndex, m::Int) = CartesianIndex(k2ij(I[1], m)...,I.I[2:end]...)

Expand Down Expand Up @@ -207,9 +210,9 @@ Base.@propagate_inbounds Base.getindex(L::LowerTriangularArray{T,1,V}, i::Intege

# setindex with il, im, ..
@inline function Base.setindex!(L::LowerTriangularArray{T,N}, x, I::Vararg{Any, M}) where {T, N, M}
@boundscheck N+1==M || throw(BoundsError(L,I))
@boundscheck N+1==M || throw(BoundsError(L, I))
i, j = I[1:2]
@boundscheck i >= j || throw(BoundsError(L, (i, j)))
@boundscheck i >= j || throw(BoundsError(L, I))
k = ij2k(i, j, L.m)
setindex!(L.data, x, k, I[3:end]...)
end
Expand Down Expand Up @@ -244,10 +247,67 @@ creates `unit_range::UnitRange` to loop over all non-zeros in the LowerTriangula
provided as arguments. Checks bounds first. All LowerTriangularMatrix's need to be of the same size.
Like `eachindex` but skips the upper triangle with zeros in `L`."""
function eachharmonic(L1::LowerTriangularArray, Ls::LowerTriangularArray...)
n = size(L1.data,1)
Base._all_match_first(L->size(L.data,1), n, L1, Ls...) || throw(BoundsError)
lowertriangular_match(L1, Ls...; horizontal_only=true) || throw(DimensionMismatch(L1, Ls...))
return eachharmonic(L1)
end
end

"""$(TYPEDSIGNATURES) Iterator for the non-horizontal dimensions in
LowerTriangularArrays. To be used like
for k in eachmatrix(L)
L[1, k]
to loop over every non-horizontal dimension of L."""
eachmatrix(L::LowerTriangularArray) = CartesianIndices(size(L)[2:end])

"""$(TYPEDSIGNATURES) Iterator for the non-horizontal dimensions in
LowerTriangularArrays. Checks that the LowerTriangularArrays match according to
`lowertriangular_match`."""
function eachmatrix(L1::LowerTriangularArray, Ls::LowerTriangularArray...)
lowertriangular_match(L1, Ls...) || throw(DimensionMismatch(L1, Ls...))
return eachmatrix(L1)
end

"""$(TYPEDSIGNATURES) True if both `L1` and `L2` are of the same size (as matrix),
but ignores singleton dimensions, e.g. 5x5 and 5x5x1 would match.
With `horizontal_only=true` (default `false`) ignore the non-horizontal dimensions,
e.g. 5x5, 5x5x1, 5x5x2 would all match."""
function lowertriangular_match(
L1::LowerTriangularArray,
L2::LowerTriangularArray;
horizontal_only::Bool=false,
)
horizontal_match = size(L1, as=Matrix)[1:2] == size(L2, as=Matrix)[1:2]
horizontal_only && return horizontal_match
return horizontal_match && length(L1) == length(L2) # ignores singleton dimensions
end


"""$(TYPEDSIGNATURES) True if all lower triangular matrices provided as arguments
match according to `lowertriangular_match` wrt to `L1` (and therefore all)."""
function lowertriangular_match(L1::LowerTriangularArray, Ls::LowerTriangularArray...; kwargs...)
length(Ls) == 0 && return true # single L1 always matches itself
# cut the Ls tuple short on every iteration of the recursion
return lowertriangular_match(L1, Ls[1]; kwargs...) && lowertriangular_match(L1, Ls[2:end]...; kwargs...)
end

"""$(TYPEDSIGNATURES)
Returns a tuple like (1,2,3) as string "1×2×3". To be used with size2x_string(size()"""
function size2x_string(t::Tuple)
s = "$(t[1])"
for i in t[2:end]
s *= "×$i"
end
return s
end

function Base.DimensionMismatch(L1::LowerTriangularArray, Ls::LowerTriangularArray...)
s = "LowerTriangularArrays do not match; $(size2x_string(size(L1, as=Matrix)))"
for L in Ls
s *= ", $(size2x_string(size(L, as=Matrix)))"
end
return DimensionMismatch(s)
end

# CONVERSIONS
"""
Expand All @@ -268,7 +328,7 @@ function LowerTriangularMatrix(M::Matrix{T}) where T # CPU version
end

# helper function for conversion etc on GPU, returns indices of the lower triangle
lowertriangle_indices(M::AbstractMatrix) = tril!(trues(size(M)))
lowertriangle_indices(M::AbstractMatrix) = lowertriangle_indices(size(M)...)
lowertriangle_indices(m::Integer, n::Integer) = tril!(trues((m,n)))

function lowertriangle_indices(M::AbstractArray{T, N}) where {T, N}
Expand Down Expand Up @@ -428,28 +488,28 @@ function Base.convert(
::Type{LowerTriangularArray{T1, N, ArrayTypeT1}},
L::LowerTriangularArray{T2, N, ArrayTypeT2},
) where {T1, T2, N, ArrayTypeT1<:AbstractArray{T1}, ArrayTypeT2<:AbstractArray{T2}}
return LowerTriangularArray{T1,N,ArrayTypeT1}(L.data, L.m, L.n)
return LowerTriangularArray{T1, N, ArrayTypeT1}(L.data, L.m, L.n)
end

function Base.convert(::Type{LowerTriangularMatrix{T}}, L::LowerTriangularMatrix) where T
return LowerTriangularMatrix{T}(L.data, L.m, L.n)
end

function Base.similar(::LowerTriangularArray{T,N,ArrayType}, I::Integer...) where {T, N, ArrayType}
function Base.similar(::LowerTriangularArray{T, N, ArrayType}, I::Integer...) where {T, N, ArrayType}
return LowerTriangularArray{T,N,ArrayType}(undef, I...)
end

function Base.similar(::LowerTriangularArray{T,N,ArrayType}, size::S) where {T, N, ArrayType, S<:Tuple}
function Base.similar(::LowerTriangularArray{T, N, ArrayType}, size::S) where {T, N, ArrayType, S<:Tuple}
return LowerTriangularArray{T,N,ArrayType}(undef, size...)
end

function Base.similar(L::LowerTriangularArray{S,N,ArrayType}, ::Type{T}) where {T, S, N, ArrayType}
function Base.similar(L::LowerTriangularArray{S, N, ArrayType}, ::Type{T}) where {T, S, N, ArrayType}
ArrayType_ = nonparametric_type(ArrayType) # TODO: not sure how else to infer this type
return LowerTriangularArray{T,N,ArrayType_{T,N}}(undef, size(L; as=Matrix)...)
return LowerTriangularArray{T, N, ArrayType_{T, N}}(undef, size(L; as=Matrix)...)
end

Base.similar(L::LowerTriangularArray{T,N,ArrayType}, ::Type{T}) where {T, N, ArrayType} =
LowerTriangularArray{T,N,ArrayType}(undef, size(L; as=Matrix)...)
Base.similar(L::LowerTriangularArray{T, N, ArrayType}, ::Type{T}) where {T, N, ArrayType} =
LowerTriangularArray{T, N, ArrayType}(undef, size(L; as=Matrix)...)
Base.similar(L::LowerTriangularArray{T}) where T = similar(L, T)

Base.prod(L::LowerTriangularArray{NF}) where NF = zero(NF)
Expand Down

0 comments on commit 575529c

Please sign in to comment.