diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 910dd744b..3a3a1f821 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -216,7 +216,7 @@ end ##### function _svd_pullback(Ȳ::Tangent, F) - ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt') + ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt) return (NoTangent(), ∂X) end _svd_pullback(Ȳ::AbstractThunk, F) = _svd_pullback(unthunk(Ȳ), F) @@ -243,39 +243,41 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD return getproperty(F, x), getproperty_svd_pullback end -# When not `ZeroTangent`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix` -function svd_rev(USV::SVD, Ū, s̄, V̄) +# When not `ZeroTangent`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄t::AbstractMatrix` +function svd_rev(USV::SVD, Ū, s̄, V̄t) # Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default U = USV.U s = USV.S - V = USV.V Vt = USV.Vt k = length(s) T = eltype(s) - F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k] - - # We do a lot of matrix operations here, so we'll try to be memory-friendly and do - # as many of the computations in-place as possible. Benchmarking shows that the in- - # place functions here are significantly faster than their out-of-place, naively - # implemented counterparts, and allocate no additional memory. - Ut = U' - FUᵀŪ = _mulsubtrans!!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU) - FVᵀV̄ = _mulsubtrans!!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV) - ImUUᵀ = _eyesubx!(U*Ut) # I - UUᵀ - ImVVᵀ = _eyesubx!(V*Vt) # I - VVᵀ - - S = Diagonal(s) - S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(s̄) - - # TODO: consider using MuladdMacro here - Ā = add!!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt - Ā = add!!(Ā, U * S̄ * Vt) - Ā = add!!(Ā, U * add!!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ)) - + UtŪ = U' * Ū + V̄tV = V̄t * Vt' + M = @inbounds T[ + if i == j + s̄[i] + else + (s[j] * (UtŪ[i, j] - UtŪ[j, i]) + s[i] * (V̄tV[j, i] - V̄tV[i, j])) / + (s[j]^2 - s[i]^2) + end for i in 1:k, j in 1:k + ] + + if size(Vt, 1) == size(Vt, 2) + # V is square, VVᵀ = I and therefore V̄ᵀ - V̄ᵀVVᵀ = 0 + Ā = (U * M .+ ((Ū .- U * UtŪ) ./ s')) * Vt + else + # If V is not square then U is, so UUᵀ == I and Ū - UUᵀŪ = 0 + Ā = U * (M * Vt .+ ((V̄t .- V̄tV * Vt) ./ s)) + end return Ā end +function svd_rev(USV::SVD, ::AbstractZero, s̄::AbstractVector, ::AbstractZero) + Ā = USV.U * Diagonal(s̄) * USV.Vt + return Ā +end + ##### ##### `eigen` ##### diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index 3d8ad923f..f13758f0e 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -1,33 +1,6 @@ # Some utility functions for optimizing linear algebra operations that aren't specific # to any particular rule definition -# F .* (X - X'), overwrites X if possible -function _mulsubtrans!!(X::AbstractMatrix{<:Real}, F::AbstractMatrix{<:Real}) - T = promote_type(eltype(X), eltype(F)) - Y = (T <: eltype(X)) ? X : similar(X, T) - k = size(X, 1) - @inbounds for j = 1:k, i = 1:j # Iterate the upper triangle - if i == j - Y[i,i] = zero(T) - else - Y[i,j], Y[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j]) - end - end - return Y -end -_mulsubtrans!!(X::AbstractZero, F::AbstractZero) = X -_mulsubtrans!!(X::AbstractZero, F::AbstractMatrix{<:Real}) = X -_mulsubtrans!!(X::AbstractMatrix{<:Real}, F::AbstractZero) = F - -# I - X, overwrites X -function _eyesubx!(X::AbstractMatrix) - n, m = size(X) - @inbounds for j = 1:m, i = 1:n - X[i,j] = (i == j) - X[i,j] - end - return X -end - _extract_imag(x) = complex(0, imag(x)) """ diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 60e2e74be..3832f63d5 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -146,18 +146,6 @@ end @test dX_thunked == dX_unthunked end end - - @testset "Helper functions" begin - X = randn(10, 10) - Y = randn(10, 10) - @test ChainRules._mulsubtrans!!(copy(X), Y) ≈ Y .* (X - X') - @test ChainRules._eyesubx!(copy(X)) ≈ I - X - - Z = randn(Float32, 10, 10) - result = ChainRules._mulsubtrans!!(copy(Z), Y) - @test result ≈ Y .* (Z - Z') - @test eltype(result) == Float64 - end end @testset "eigendecomposition" begin