From 2cc27e2124fc7746afdd7753fbc9f19c1c24ac13 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 3 May 2022 22:07:22 +0200 Subject: [PATCH] Fix exp rules for some matrices (#596) * Avoid un-balancing the needed intermediate * Test exp for imbalanced unsquared matrix * Increment patch number * Apply suggestions from code review Co-authored-by: Frames Catherine White * Add exhaustive tests * Eliminate truncation error * Increment patch number * Increment patch version number Co-authored-by: Frames Catherine White --- Project.toml | 2 +- src/rulesets/LinearAlgebra/matfun.jl | 5 +++- test/rulesets/LinearAlgebra/matfun.jl | 36 +++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 3cceb9250..05fde5b0b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.28.3" +version = "1.28.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 4b41dd565..8ed6c967d 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -219,6 +219,9 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat} X *= X push!(Xpows, X) end + else + # Xpows[1] must remain balanced for computing the Fréchet derivative + X = copy(X) end _unbalance!(X, ilo, ihi, scale, n) @@ -247,7 +250,7 @@ function _matfun_frechet!( ∂P = copy(∂A2) ∂W = C[4] * ∂P ∂V = C[3] * ∂P - for k in 2:(length(Apows) - 1) + for k in 2:length(Apows) k2 = 2 * k P = Apows[k - 1] ∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P diff --git a/test/rulesets/LinearAlgebra/matfun.jl b/test/rulesets/LinearAlgebra/matfun.jl index 9ee4e0705..03ba58023 100644 --- a/test/rulesets/LinearAlgebra/matfun.jl +++ b/test/rulesets/LinearAlgebra/matfun.jl @@ -14,6 +14,24 @@ A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0] test_frule(LinearAlgebra.exp!, A) end + @testset "imbalanced A with no squaring" begin + # https://github.com/JuliaDiff/ChainRules.jl/issues/595 + A = [ + -0.007623430669065629 -0.567237096385192 0.4419041897734335; + 2.090838913114862 -1.254084243281689 -0.04145771190198238; + 2.3397892123412833 -0.6650489083959324 0.6387266010923911 + ] + test_frule(LinearAlgebra.exp!, A) + end + @testset "exhaustive test" begin + # added to ensure we never hit truncation error + # https://github.com/JuliaDiff/ChainRules.jl/issues/595 + rng = MersenneTwister(1) + for _ in 1:100 + A = randn(rng, 3, 3) + test_frule(LinearAlgebra.exp!, A) + end + end @testset "hermitian A, T=$T" for T in (Float64, ComplexF64) A = Matrix(Hermitian(randn(T, n, n))) test_frule(LinearAlgebra.exp!, A) @@ -48,6 +66,24 @@ A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0] test_rrule(exp, A; check_inferred=false) end + @testset "imbalanced A with no squaring" begin + # https://github.com/JuliaDiff/ChainRules.jl/issues/595 + A = [ + -0.007623430669065629 -0.567237096385192 0.4419041897734335; + 2.090838913114862 -1.254084243281689 -0.04145771190198238; + 2.3397892123412833 -0.6650489083959324 0.6387266010923911 + ] + test_rrule(LinearAlgebra.exp, A; check_inferred=false) + end + @testset "exhaustive test" begin + # added to ensure we never hit truncation error + # https://github.com/JuliaDiff/ChainRules.jl/issues/595 + rng = MersenneTwister(1) + for _ in 1:100 + A = randn(rng, 3, 3) + test_rrule(LinearAlgebra.exp, A; check_inferred=false) + end + end @testset "hermitian A, T=$T" for T in (Float64, ComplexF64) A = Matrix(Hermitian(randn(T, n, n))) test_rrule(exp, A; check_inferred=false)