diff --git a/src/LowerTriangularMatrices/lower_triangular_matrix.jl b/src/LowerTriangularMatrices/lower_triangular_matrix.jl index d04f233cf..08944d2b8 100644 --- a/src/LowerTriangularMatrices/lower_triangular_matrix.jl +++ b/src/LowerTriangularMatrices/lower_triangular_matrix.jl @@ -44,26 +44,11 @@ Base.length(L::LowerTriangularArray) = length(L.data) """$(TYPEDSIGNATURES) Size of a `LowerTriangularArray` defined as size of the flattened array if `as <: AbstractVector` and as if it were a full matrix when `as <: AbstractMatrix`` .""" -function Base.size(L::LowerTriangularArray; as::Type=Vector) - if as<:AbstractVector - return size(L.data) - elseif as<:AbstractMatrix - return matrix_size(L) - else - error("Unknown `as` input type in size(::LowerTriangularArray; as=...).") - end -end - -"""$(TYPEDSIGNATURES) -Size of a `LowerTriangularArray` defined as size of the flattened array if `as <: AbstractVector` -and as if it were a full matrix when `as <: AbstractMatrix`` .""" +Base.size(L::LowerTriangularArray; as::T=Vector) where T = size(L, as) Base.size(L::LowerTriangularArray, i::Integer; as::T=Vector) where T = size(L; as=as)[i] -"""$(TYPEDSIGNATURES) -Size of a expanded `LowerTriangularArray` as if it were a full matrix, -returns `(L.m, L.n, size(L.data)[2:end])``.""" -matrix_size(L::LowerTriangularArray) = (L.m, L.n, size(L.data)[2:end]...) -matrix_size(L::LowerTriangularArray, i::Int) = matrix_size(L)[i] +Base.size(L::LowerTriangularArray, as::Type{<:AbstractMatrix}) = (L.m, L.n, size(L.data)[2:end]...) +Base.size(L::LowerTriangularArray, as::Type{<:AbstractVector}) = size(L.data) # sizeof the underlying data vector Base.sizeof(L::LowerTriangularArray) = sizeof(L.data) @@ -533,7 +518,7 @@ function Base.similar( ::Type{T}, ) where {N, ArrayType, T} L = find_L(bc) - return LowerTriangularArray{T, N, ArrayType{T,N}}(undef, matrix_size(L)) + return LowerTriangularArray{T, N, ArrayType{T,N}}(undef, size(L; as=Matrix)) end # same function as above, but needs to be defined for both CPU and GPU style @@ -542,7 +527,7 @@ function Base.similar( ::Type{T}, ) where {N, ArrayType, T} L = find_L(bc) - return LowerTriangularArray{T, N, ArrayType{T,N}}(undef, matrix_size(L)) + return LowerTriangularArray{T, N, ArrayType{T,N}}(undef, size(L; as=Matrix)) end function GPUArrays.backend(