diff --git a/Project.toml b/Project.toml index f22c04ab..4ec9f49a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.12.0" +version = "1.13.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 170e07d3..4d0b4e29 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -374,6 +374,32 @@ fillsimilar(a::Ones{T}, axes...) where T = Ones{T}(axes...) fillsimilar(a::Zeros{T}, axes...) where T = Zeros{T}(axes...) fillsimilar(a::AbstractFill, axes...) = Fill(getindex_value(a), axes...) +# functions +function Base.sqrt(a::AbstractFillMatrix{<:Union{Real, Complex}}) + Base.require_one_based_indexing(a) + size(a,1) == size(a,2) || throw(DimensionMismatch("matrix is not square: dimensions are $(size(a))")) + _sqrt(a) +end +_sqrt(a::AbstractZerosMatrix) = float(a) +function _sqrt(a::AbstractFillMatrix) + n = size(a,1) + n == 0 && return float(a) + v = getindex_value(a) + Fill(√(v/n), axes(a)) +end +function Base.cbrt(a::AbstractFillMatrix{<:Real}) + Base.require_one_based_indexing(a) + size(a,1) == size(a,2) || throw(DimensionMismatch("matrix is not square: dimensions are $(size(a))")) + _cbrt(a) +end +_cbrt(a::AbstractZerosMatrix) = float(a) +function _cbrt(a::AbstractFillMatrix) + n = size(a,1) + n == 0 && return float(a) + v = getindex_value(a) + Fill(cbrt(v)/cbrt(n)^2, axes(a)) +end + struct RectDiagonal{T,V<:AbstractVector{T},Axes<:Tuple{Vararg{AbstractUnitRange,2}}} <: AbstractMatrix{T} diag::V axes::Axes diff --git a/test/runtests.jl b/test/runtests.jl index 0163df8a..122586c4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2872,3 +2872,32 @@ end @test triu(Z, 2) === Z @test tril(Z, 2) === Z end + +@testset "sqrt/cbrt" begin + F = Fill(4, 4, 4) + A = Array(F) + @test sqrt(F) ≈ sqrt(A) rtol=3e-8 + @test sqrt(F)^2 ≈ F + F = Fill(4+4im, 4, 4) + A = Array(F) + @test sqrt(F) ≈ sqrt(A) rtol=1e-8 + @test sqrt(F)^2 ≈ F + F = Fill(-4, 4, 4) + A = Array(F) + if VERSION >= v"1.11.0-rc3" + @test cbrt(F) ≈ cbrt(A) rtol=1e-5 + end + @test cbrt(F)^3 ≈ F + + # avoid overflow + F = Fill(4, typemax(Int), typemax(Int)) + @test sqrt(F)^2 ≈ F + @test cbrt(F)^3 ≈ F + + # zeros + F = Zeros(4, 4) + A = Array(F) + @test sqrt(F) ≈ sqrt(A) atol=1e-14 + @test sqrt(F)^2 == F + @test cbrt(F)^3 == F +end