Skip to content

Commit

Permalink
size(LTA) with dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
maximilian-gelbrecht committed Jun 13, 2024
1 parent a594be7 commit 1379fb9
Showing 1 changed file with 5 additions and 20 deletions.
25 changes: 5 additions & 20 deletions src/LowerTriangularMatrices/lower_triangular_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,11 @@ Base.length(L::LowerTriangularArray) = length(L.data)
"""$(TYPEDSIGNATURES)
Size of a `LowerTriangularArray` defined as size of the flattened array if `as <: AbstractVector`
and as if it were a full matrix when `as <: AbstractMatrix`` ."""
function Base.size(L::LowerTriangularArray; as::Type=Vector)
if as<:AbstractVector
return size(L.data)
elseif as<:AbstractMatrix
return matrix_size(L)
else
error("Unknown `as` input type in size(::LowerTriangularArray; as=...).")
end
end

"""$(TYPEDSIGNATURES)
Size of a `LowerTriangularArray` defined as size of the flattened array if `as <: AbstractVector`
and as if it were a full matrix when `as <: AbstractMatrix`` ."""
Base.size(L::LowerTriangularArray; as::T=Vector) where T = size(L, as)
Base.size(L::LowerTriangularArray, i::Integer; as::T=Vector) where T = size(L; as=as)[i]

"""$(TYPEDSIGNATURES)
Size of a expanded `LowerTriangularArray` as if it were a full matrix,
returns `(L.m, L.n, size(L.data)[2:end])``."""
matrix_size(L::LowerTriangularArray) = (L.m, L.n, size(L.data)[2:end]...)
matrix_size(L::LowerTriangularArray, i::Int) = matrix_size(L)[i]
Base.size(L::LowerTriangularArray, as::Type{<:AbstractMatrix}) = (L.m, L.n, size(L.data)[2:end]...)
Base.size(L::LowerTriangularArray, as::Type{<:AbstractVector}) = size(L.data)

# sizeof the underlying data vector
Base.sizeof(L::LowerTriangularArray) = sizeof(L.data)
Expand Down Expand Up @@ -533,7 +518,7 @@ function Base.similar(
::Type{T},
) where {N, ArrayType, T}
L = find_L(bc)
return LowerTriangularArray{T, N, ArrayType{T,N}}(undef, matrix_size(L))
return LowerTriangularArray{T, N, ArrayType{T,N}}(undef, size(L; as=Matrix))
end

# same function as above, but needs to be defined for both CPU and GPU style
Expand All @@ -542,7 +527,7 @@ function Base.similar(
::Type{T},
) where {N, ArrayType, T}
L = find_L(bc)
return LowerTriangularArray{T, N, ArrayType{T,N}}(undef, matrix_size(L))
return LowerTriangularArray{T, N, ArrayType{T,N}}(undef, size(L; as=Matrix))
end

function GPUArrays.backend(
Expand Down

0 comments on commit 1379fb9

Please sign in to comment.