Skip to content

Commit

Permalink
Restrict ldiv! rules to Cholesky (#1257)
Browse files Browse the repository at this point in the history
* Restrict `ldiv!` rules to `Cholesky`

* Apply suggestions from code review

* Update src/internal_rules.jl

* Apply suggestions from code review
  • Loading branch information
devmotion authored Jan 29, 2024
1 parent 3095a4f commit d4f6400
Showing 1 changed file with 14 additions and 29 deletions.
43 changes: 14 additions & 29 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -621,10 +621,10 @@ end
function EnzymeRules.forward(
func::Const{typeof(ldiv!)},
RT::Type,
fact::Annotation{C},
B,
fact::Annotation{<:Cholesky},
B;
kwargs...
) where {C <: Union{Cholesky,Array}}
)
if isa(B, Const)
@assert (RT <: Const)
return func.val(fact.val, B.val; kwargs...)
Expand Down Expand Up @@ -656,19 +656,11 @@ function EnzymeRules.forward(
fact.dval[b]
end

if C <: Array
mul!(dB, dfact, retval, -1, 1)
else
tmp = dfact.U * retval

dB .-= dfact.L * tmp

# if mul! was implemented for LU, this would be faster
# mul!(dB, dfact.L, tmp, -1, 1)
end
tmp = dfact.U * retval
mul!(dB, dfact.L, tmp, -1, 1)
end

ldiv!(fact.val, dB; kwargs...)
func.val(fact.val, dB; kwargs...)
end

if RT <: Const
Expand Down Expand Up @@ -750,12 +742,11 @@ function EnzymeRules.augmented_primal(
func::Const{typeof(ldiv!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}},

A::Annotation{AType},
A::Annotation{<:Cholesky},
B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated};
kwargs...
) where {AType <: Union{Cholesky, Array}}

ldiv!(A.val, B.val; kwargs...)
)
func.val(A.val, B.val; kwargs...)

cache_Bout = if !isa(A, Const) && !isa(B, Const)
if EnzymeRules.overwritten(config)[3]
Expand Down Expand Up @@ -797,11 +788,10 @@ function EnzymeRules.reverse(
func::Const{typeof(ldiv!)},
dret,
cache,
A::Annotation{AType},
B;
A::Annotation{<:Cholesky},
B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated};
kwargs...
) where {AType <: Union{Cholesky,Array}}

)
if !isa(B, Const)

(cache_A, cache_Bout) = cache
Expand All @@ -813,15 +803,10 @@ function EnzymeRules.reverse(
# dB = z, where z = inv(A^T) dB
# dA −= z B(out)^T

func.val(cache_A, dB, kwargs...)
func.val(cache_A, dB; kwargs...)
if !isa(A, Const)
dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b]

if AType <: Array
mul!(dA, dB, transpose(cache_Bout), -1, 1)
else
mul!(dA.factors, dB, transpose(cache_Bout), -1, 1)
end
mul!(dA.factors, dB, transpose(cache_Bout), -1, 1)
end
end
end
Expand Down

0 comments on commit d4f6400

Please sign in to comment.