diff --git a/src/internal_rules.jl b/src/internal_rules.jl index d626de6685..d073a78c47 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -482,3 +482,78 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Ty return (nothing, nothing) end end + +function EnzymeRules.forward( + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated; + kwargs... + ) + inds = sortperm(xs.val; kwargs...) + xs.val .= xs.val[inds] + xs.dval .= xs.dval[inds] + if RT <: Const + return xs.val + elseif RT <: DuplicatedNoNeed + return xs.dval + else + return xs + end +end + +function EnzymeRules.forward( + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, + xs::BatchDuplicated{T, N}; + kwargs... + ) where {T, N} + inds = sortperm(xs.val; kwargs...) + xs.val .= xs.val[inds] + for i in 1:N + xs.dval[i] .= xs.dval[i][inds] + end + if RT <: Const + return xs.val + elseif RT <: BatchDuplicatedNoNeed + return xs.dval + else + return xs + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated; + kwargs... + ) + inds = sortperm(xs.val; kwargs...) + xs.val .= xs.val[inds] + xs.dval .= xs.dval[inds] + if EnzymeRules.needs_primal(config) + primal = xs.val + else + primal = nothing + end + if RT <: Const + shadow = nothing + else + shadow = xs.dval + end + return EnzymeRules.AugmentedReturn(primal, shadow, inds) +end + +function EnzymeRules.reverse( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + tape, + xs::Duplicated; + kwargs..., + ) + inds = tape + back_inds = sortperm(inds) + xs.dval .= xs.dval[back_inds] + return (nothing,) +end diff --git a/test/internal_rules.jl b/test/internal_rules.jl new file mode 100644 index 0000000000..ccf61fef25 --- /dev/null +++ b/test/internal_rules.jl @@ -0,0 +1,86 @@ +module InternalRules + +using Enzyme +using Enzyme.EnzymeRules +using Test + +@testset "Internal rules" begin + function f1(x) + a = [1.0, 3.0, x] + sort!(a) + return a[2] + end + + @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 + @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) + @test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1 + @test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 + @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0) + @test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0 + + function f2(x) + a = [1.0, -3.0, -x, -2x, x] + sort!(a; rev=true, lt=(x, y) -> abs(x) < abs(y) || (abs(x) == abs(y) && x < y)) + return sum(a .* [1, 2, 3, 4, 5]) + end + + @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 + @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) + @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 +end + +@testset "Linear Solve" begin + A = Float64[2 3; 5 7] + dA = zero(A) + b = Float64[11, 13] + db = zero(b) + + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) + + tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) + + dy = Float64[17, 19] + copyto!(shadow, dy) + + pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape) + + z = transpose(A) \ dy + + y = A \ b + @test dA ≈ (-z * transpose(y)) + @test db ≈ z + + db = zero(b) + + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) + + tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db)) + + dy = Float64[17, 19] + copyto!(shadow, dy) + + pullback(Const(\), Const(A), Duplicated(b, db), tape) + + z = transpose(A) \ dy + + y = A \ b + @test db ≈ z + + dA = zero(A) + + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)}) + + tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b)) + + dy = Float64[17, 19] + copyto!(shadow, dy) + + pullback(Const(\), Duplicated(A, dA), Const(b), tape) + + z = transpose(A) \ dy + + y = A \ b + @test dA ≈ (-z * transpose(y)) +end + +end # InternalRules diff --git a/test/runtests.jl b/test/runtests.jl index 6231db870c..c1d7d49b72 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -76,6 +76,7 @@ include("typetree.jl") include("rrules.jl") include("kwrules.jl") include("kwrrules.jl") + include("internal_rules.jl") @static if VERSION ≥ v"1.9-" # XXX invalidation does not work on Julia 1.8 include("ruleinvalidation.jl") @@ -2615,60 +2616,6 @@ end end end -@testset "Linear Solve" begin - A = Float64[2 3; 5 7] - dA = zero(A) - b = Float64[11, 13] - db = zero(b) - - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) - - tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) - - dy = Float64[17, 19] - copyto!(shadow, dy) - - pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape) - - z = transpose(A) \ dy - - y = A \ b - @test dA ≈ (-z * transpose(y)) - @test db ≈ z - - db = zero(b) - - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) - - tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db)) - - dy = Float64[17, 19] - copyto!(shadow, dy) - - pullback(Const(\), Const(A), Duplicated(b, db), tape) - - z = transpose(A) \ dy - - y = A \ b - @test db ≈ z - - dA = zero(A) - - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)}) - - tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b)) - - dy = Float64[17, 19] - copyto!(shadow, dy) - - pullback(Const(\), Duplicated(A, dA), Const(b), tape) - - z = transpose(A) \ dy - - y = A \ b - @test dA ≈ (-z * transpose(y)) -end - @static if VERSION >= v"1.7-" @testset "hvcat_fill" begin ar = Matrix{Float64}(undef, 2, 3)