Skip to content

Commit

Permalink
Fix exp rules for some matrices (#596)
Browse files Browse the repository at this point in the history
* Avoid un-balancing the needed intermediate

* Test exp for imbalanced unsquared matrix

* Increment patch number

* Apply suggestions from code review

Co-authored-by: Frames Catherine White <[email protected]>

* Add exhaustive tests

* Eliminate truncation error

* Increment patch number

* Increment patch version number

Co-authored-by: Frames Catherine White <[email protected]>
  • Loading branch information
sethaxen and oxinabox authored May 3, 2022
1 parent 7b5f4d1 commit 2cc27e2
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.28.3"
version = "1.28.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
5 changes: 4 additions & 1 deletion src/rulesets/LinearAlgebra/matfun.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat}
X *= X
push!(Xpows, X)
end
else
# Xpows[1] must remain balanced for computing the Fréchet derivative
X = copy(X)
end

_unbalance!(X, ilo, ihi, scale, n)
Expand Down Expand Up @@ -247,7 +250,7 @@ function _matfun_frechet!(
∂P = copy(∂A2)
∂W = C[4] * ∂P
∂V = C[3] * ∂P
for k in 2:(length(Apows) - 1)
for k in 2:length(Apows)
k2 = 2 * k
P = Apows[k - 1]
∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P
Expand Down
36 changes: 36 additions & 0 deletions test/rulesets/LinearAlgebra/matfun.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,24 @@
A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
test_frule(LinearAlgebra.exp!, A)
end
@testset "imbalanced A with no squaring" begin
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
A = [
-0.007623430669065629 -0.567237096385192 0.4419041897734335;
2.090838913114862 -1.254084243281689 -0.04145771190198238;
2.3397892123412833 -0.6650489083959324 0.6387266010923911
]
test_frule(LinearAlgebra.exp!, A)
end
@testset "exhaustive test" begin
# added to ensure we never hit truncation error
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
rng = MersenneTwister(1)
for _ in 1:100
A = randn(rng, 3, 3)
test_frule(LinearAlgebra.exp!, A)
end
end
@testset "hermitian A, T=$T" for T in (Float64, ComplexF64)
A = Matrix(Hermitian(randn(T, n, n)))
test_frule(LinearAlgebra.exp!, A)
Expand Down Expand Up @@ -48,6 +66,24 @@
A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
test_rrule(exp, A; check_inferred=false)
end
@testset "imbalanced A with no squaring" begin
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
A = [
-0.007623430669065629 -0.567237096385192 0.4419041897734335;
2.090838913114862 -1.254084243281689 -0.04145771190198238;
2.3397892123412833 -0.6650489083959324 0.6387266010923911
]
test_rrule(LinearAlgebra.exp, A; check_inferred=false)
end
@testset "exhaustive test" begin
# added to ensure we never hit truncation error
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
rng = MersenneTwister(1)
for _ in 1:100
A = randn(rng, 3, 3)
test_rrule(LinearAlgebra.exp, A; check_inferred=false)
end
end
@testset "hermitian A, T=$T" for T in (Float64, ComplexF64)
A = Matrix(Hermitian(randn(T, n, n)))
test_rrule(exp, A; check_inferred=false)
Expand Down

2 comments on commit 2cc27e2

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/59609

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.28.4 -m "<description of version>" 2cc27e2124fc7746afdd7753fbc9f19c1c24ac13
git push origin v1.28.4

Please sign in to comment.