Skip to content

Commit

Permalink
Merge branch 'master' into jishnub/eigen
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored Aug 29, 2024
2 parents 5fa3fba + 6f61dc3 commit 38b9c08
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 13 deletions.
14 changes: 12 additions & 2 deletions ext/FillArraysSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
module FillArraysSparseArraysExt

using SparseArrays
using FillArrays
using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix, getindex_value
using SparseArrays: SparseVectorUnion
import Base: convert, kron
using FillArrays
using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix, getindex_value, AbstractFillVector, _fill_dot
# Specifying the full namespace is necessary because of https://github.com/JuliaLang/julia/issues/48533
# See https://github.com/JuliaStats/LogExpFunctions.jl/pull/63
using FillArrays.LinearAlgebra
import LinearAlgebra: dot, kron, I

##################
## Sparse arrays
Expand Down Expand Up @@ -58,4 +60,12 @@ end
# TODO: remove in v2.0
@deprecate kron(E1::RectDiagonalFill, E2::RectDiagonalFill) kron(sparse(E1), sparse(E2))

# Ambiguity. see #178
if VERSION >= v"1.8"
dot(x::AbstractFillVector, y::SparseVectorUnion) = _fill_dot(x, y)
else
dot(x::AbstractFillVector{<:Number}, y::SparseVectorUnion{<:Number}) = _fill_dot(x, y)
end


end # module
40 changes: 34 additions & 6 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,32 @@ fillsimilar(a::Ones{T}, axes...) where T = Ones{T}(axes...)
fillsimilar(a::Zeros{T}, axes...) where T = Zeros{T}(axes...)
fillsimilar(a::AbstractFill, axes...) = Fill(getindex_value(a), axes...)

# functions
function Base.sqrt(a::AbstractFillMatrix{<:Union{Real, Complex}})
Base.require_one_based_indexing(a)
size(a,1) == size(a,2) || throw(DimensionMismatch("matrix is not square: dimensions are $(size(a))"))
_sqrt(a)
end
_sqrt(a::AbstractZerosMatrix) = float(a)
function _sqrt(a::AbstractFillMatrix)
n = size(a,1)
n == 0 && return float(a)
v = getindex_value(a)
Fill((v/n), axes(a))
end
function Base.cbrt(a::AbstractFillMatrix{<:Real})
Base.require_one_based_indexing(a)
size(a,1) == size(a,2) || throw(DimensionMismatch("matrix is not square: dimensions are $(size(a))"))
_cbrt(a)
end
_cbrt(a::AbstractZerosMatrix) = float(a)
function _cbrt(a::AbstractFillMatrix)
n = size(a,1)
n == 0 && return float(a)
v = getindex_value(a)
Fill(cbrt(v)/cbrt(n)^2, axes(a))
end

struct RectDiagonal{T,V<:AbstractVector{T},Axes<:Tuple{Vararg{AbstractUnitRange,2}}} <: AbstractMatrix{T}
diag::V
axes::Axes
Expand Down Expand Up @@ -529,11 +555,13 @@ for (Typ, funcs, func) in ((:AbstractZeros, :zeros, :zero), (:AbstractOnes, :one
end
end

# temporary patch. should be a PR(#48895) to LinearAlgebra
Diagonal{T}(A::AbstractFillMatrix) where T = Diagonal{T}(diag(A))
function convert(::Type{T}, A::AbstractFillMatrix) where T<:Diagonal
checksquare(A)
isdiag(A) ? T(A) : throw(InexactError(:convert, T, A))
if VERSION < v"1.11-"
# temporary patch. should be a PR(#48895) to LinearAlgebra
Diagonal{T}(A::AbstractFillMatrix) where T = Diagonal{T}(diag(A))
function convert(::Type{T}, A::AbstractFillMatrix) where T<:Diagonal
checksquare(A)
isdiag(A) ? T(diag(A)) : throw(InexactError(:convert, T, A))
end
end

Base.StepRangeLen(F::AbstractFillVector{T}) where T = StepRangeLen(getindex_value(F), zero(T), length(F))
Expand Down Expand Up @@ -608,7 +636,7 @@ diff(x::AbstractFillVector{T}) where T = Zeros{T}(length(x)-1)
# unique
#########

unique(x::AbstractFill{T}) where T = isempty(x) ? T[] : T[getindex_value(x)]
unique(x::AbstractFill) = fillsimilar(x, Int(!isempty(x)))
allunique(x::AbstractFill) = length(x) < 2

#########
Expand Down
4 changes: 3 additions & 1 deletion src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,9 @@ dot(a::AbstractVector, b::AbstractFillVector) = _fill_dot_rev(a, b)
function dot(u::AbstractVector, E::Eye, v::AbstractVector)
length(u) == size(E,1) && length(v) == size(E,2) ||
throw(DimensionMismatch("dot product arguments have dimensions $(length(u))×$(size(E))×$(length(v))"))
dot(u, v)
d = dot(u,v)
T = typeof(one(eltype(E)) * d)
convert(T, d)
end

function dot(u::AbstractVector, D::Diagonal{<:Any,<:Fill}, v::AbstractVector)
Expand Down
13 changes: 13 additions & 0 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,11 @@ function triu(A::OneElementMatrix, k::Integer=0)
OneElement(nzband < k ? zero(A.val) : A.val, A.ind, axes(A))
end


# issymmetric
issymmetric(O::OneElement) = axes(O,1) == axes(O,2) && isdiag(O) && issymmetric(getindex_value(O))
ishermitian(O::OneElement) = axes(O,1) == axes(O,2) && isdiag(O) && ishermitian(getindex_value(O))

# diag
function diag(O::OneElementMatrix, k::Integer=0)
Base.require_one_based_indexing(O)
Expand Down Expand Up @@ -436,6 +441,14 @@ permutedims(o::OneElementMatrix) = OneElement(o.val, reverse(o.ind), reverse(o.a
permutedims(o::OneElementVector) = reshape(o, (1, length(o)))
permutedims(o::OneElement, dims) = OneElement(o.val, _permute(o.ind, dims), _permute(o.axes, dims))

# unique
function unique(O::OneElement)
v = getindex_value(O)
len = iszero(v) ? 1 : min(2, length(O))
OneElement(getindex_value(O), len, len)
end
allunique(O::OneElement) = length(O) <= 1 || (length(O) < 3 && !iszero(getindex_value(O)))

# show
_maybesize(t::Tuple{Base.OneTo{Int}, Vararg{Base.OneTo{Int}}}) = size.(t,1)
_maybesize(t) = t
Expand Down
87 changes: 83 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ end
@testset "unique" begin
@test unique(Fill(12, 20)) == unique(fill(12, 20))
@test unique(Fill(1, 0)) == []
@test unique(Zeros(0)) isa Vector{Float64}
@test unique(Zeros(0)) == Zeros(0)
@test !allunique(Fill("a", 2))
@test allunique(Ones(0))
end
Expand Down Expand Up @@ -1438,9 +1438,14 @@ end
@test axes(E .+ E) === axes(E)
end

@testset "Issue #31" begin
@test convert(SparseMatrixCSC{Float64,Int64}, Zeros{Float64}(3, 3)) == spzeros(3, 3)
@test sparse(Zeros(4, 2)) == spzeros(4, 2)
@testset "Issues" begin
@testset "#31" begin
@test convert(SparseMatrixCSC{Float64,Int64}, Zeros{Float64}(3, 3)) == spzeros(3, 3)
@test sparse(Zeros(4, 2)) == spzeros(4, 2)
end
@testset "#178" begin
@test Zeros(10)'*spzeros(10) == 0
end
end

@testset "Adjoint/Transpose/permutedims" begin
Expand Down Expand Up @@ -1943,6 +1948,8 @@ end
@test dot(Fill(2,N),1:N) == dot(Fill(2,N),1:N) == dot(1:N,Fill(2,N)) == 2*sum(1:N)
end

@test dot(1:4, Eye(4), 1:4) === dot(1:4, oneunit(eltype(Eye(4))) * I(4), 1:4)

@test_throws DimensionMismatch dot(u[1:end-1], D, v)
@test_throws DimensionMismatch dot(u[1:end-1], D, v[1:end-1])

Expand Down Expand Up @@ -2708,6 +2715,42 @@ end
@test repr(B) == "OneElement(2, (1, 2), (Base.IdentityUnitRange(1:1), Base.IdentityUnitRange(2:2)))"
end

@testset "issymmetric/ishermitian" begin
for el in (2, 3+0im, 4+5im, SMatrix{2,2}(1:4), SMatrix{2,3}(1:6)), size in [(3,3), (3,4)]
O = OneElement(el, (2,2), size)
A = Array(O)
@test issymmetric(O) == issymmetric(A)
@test ishermitian(O) == ishermitian(A)
O = OneElement(el, (1,2), size)
A = Array(O)
@test issymmetric(O) == issymmetric(A)
@test ishermitian(O) == ishermitian(A)
O = OneElement(el, (5,5), size)
A = Array(O)
@test issymmetric(O) == issymmetric(A)
@test ishermitian(O) == ishermitian(A)
end
end

@testset "unique" begin
@testset for n in 1:3
O = OneElement(5, 2, n)
@test unique(O) == unique(Array(O))
@test allunique(O) == allunique(Array(O))
O = OneElement(0, 2, n)
@test unique(O) == unique(Array(O))
@test allunique(O) == allunique(Array(O))
@testset for m in 1:4
O2 = OneElement(2, (2,1), (m,n))
@test unique(O2) == unique(Array(O2))
@test allunique(O2) == allunique(Array(O2))
O2 = OneElement(0, (2,1), (m,n))
@test unique(O2) == unique(Array(O2))
@test allunique(O2) == allunique(Array(O2))
end
end
end

@testset "sum" begin
@testset "OneElement($v, $ind, $sz)" for (v, ind, sz) in (
(Int8(2), 3, 4),
Expand Down Expand Up @@ -2966,3 +3009,39 @@ end
end
end
end

@testset "Diagonal conversion (#389)" begin
@test convert(Diagonal{Int, Vector{Int}}, Zeros(5,5)) isa Diagonal{Int,Vector{Int}}
@test convert(Diagonal{Int, Vector{Int}}, Zeros(5,5)) == zeros(5,5)
@test Diagonal{Int}(Zeros(5,5)) Diagonal(Zeros{Int}(5))
@test Diagonal{Int}(Ones(5,5)) Diagonal(Ones{Int}(5))
end

@testset "sqrt/cbrt" begin
F = Fill(4, 4, 4)
A = Array(F)
@test sqrt(F) sqrt(A) rtol=3e-8
@test sqrt(F)^2 F
F = Fill(4+4im, 4, 4)
A = Array(F)
@test sqrt(F) sqrt(A) rtol=1e-8
@test sqrt(F)^2 F
F = Fill(-4, 4, 4)
A = Array(F)
if VERSION >= v"1.11.0-rc3"
@test cbrt(F) cbrt(A) rtol=1e-5
end
@test cbrt(F)^3 F

# avoid overflow
F = Fill(4, typemax(Int), typemax(Int))
@test sqrt(F)^2 F
@test cbrt(F)^3 F

# zeros
F = Zeros(4, 4)
A = Array(F)
@test sqrt(F) sqrt(A) atol=1e-14
@test sqrt(F)^2 == F
@test cbrt(F)^3 == F
end

0 comments on commit 38b9c08

Please sign in to comment.