Skip to content

Commit

Permalink
sort! rules (#1000)
Browse files Browse the repository at this point in the history
* sort! rules

* Sort in augmented primal

* sort! rule tests

* Batched sort! rule

* Move A \ B rule test

* Fix after rebase

* Add missing end

---------

Co-authored-by: William Moses <[email protected]>
  • Loading branch information
jgreener64 and wsmoses authored Dec 17, 2023
1 parent 10d380b commit 770b064
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 54 deletions.
75 changes: 75 additions & 0 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,78 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Ty
return (nothing, nothing)
end
end

function EnzymeRules.forward(
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
xs::Duplicated;
kwargs...
)
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
xs.dval .= xs.dval[inds]
if RT <: Const
return xs.val
elseif RT <: DuplicatedNoNeed
return xs.dval
else
return xs
end
end

function EnzymeRules.forward(
::Const{typeof(sort!)},
RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}},
xs::BatchDuplicated{T, N};
kwargs...
) where {T, N}
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
for i in 1:N
xs.dval[i] .= xs.dval[i][inds]
end
if RT <: Const
return xs.val
elseif RT <: BatchDuplicatedNoNeed
return xs.dval
else
return xs
end
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.ConfigWidth{1},
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
xs::Duplicated;
kwargs...
)
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
xs.dval .= xs.dval[inds]
if EnzymeRules.needs_primal(config)
primal = xs.val
else
primal = nothing
end
if RT <: Const
shadow = nothing
else
shadow = xs.dval
end
return EnzymeRules.AugmentedReturn(primal, shadow, inds)
end

function EnzymeRules.reverse(
config::EnzymeRules.ConfigWidth{1},
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
tape,
xs::Duplicated;
kwargs...,
)
inds = tape
back_inds = sortperm(inds)
xs.dval .= xs.dval[back_inds]
return (nothing,)
end
86 changes: 86 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
module InternalRules

using Enzyme
using Enzyme.EnzymeRules
using Test

@testset "Internal rules" begin
function f1(x)
a = [1.0, 3.0, x]
sort!(a)
return a[2]
end

@test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1
@test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0)
@test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1
@test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0
@test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0)
@test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0

function f2(x)
a = [1.0, -3.0, -x, -2x, x]
sort!(a; rev=true, lt=(x, y) -> abs(x) < abs(y) || (abs(x) == abs(y) && x < y))
return sum(a .* [1, 2, 3, 4, 5])
end

@test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3
@test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0)
@test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3
end

@testset "Linear Solve" begin
A = Float64[2 3; 5 7]
dA = zero(A)
b = Float64[11, 13]
db = zero(b)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)})

tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape)

z = transpose(A) \ dy

y = A \ b
@test dA (-z * transpose(y))
@test db z

db = zero(b)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)})

tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Const(A), Duplicated(b, db), tape)

z = transpose(A) \ dy

y = A \ b
@test db z

dA = zero(A)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)})

tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Duplicated(A, dA), Const(b), tape)

z = transpose(A) \ dy

y = A \ b
@test dA (-z * transpose(y))
end

end # InternalRules
55 changes: 1 addition & 54 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ include("typetree.jl")
include("rrules.jl")
include("kwrules.jl")
include("kwrrules.jl")
include("internal_rules.jl")
@static if VERSION v"1.9-"
# XXX invalidation does not work on Julia 1.8
include("ruleinvalidation.jl")
Expand Down Expand Up @@ -2615,60 +2616,6 @@ end
end
end

@testset "Linear Solve" begin
A = Float64[2 3; 5 7]
dA = zero(A)
b = Float64[11, 13]
db = zero(b)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)})

tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape)

z = transpose(A) \ dy

y = A \ b
@test dA (-z * transpose(y))
@test db z

db = zero(b)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)})

tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Const(A), Duplicated(b, db), tape)

z = transpose(A) \ dy

y = A \ b
@test db z

dA = zero(A)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)})

tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Duplicated(A, dA), Const(b), tape)

z = transpose(A) \ dy

y = A \ b
@test dA (-z * transpose(y))
end

@static if VERSION >= v"1.7-"
@testset "hvcat_fill" begin
ar = Matrix{Float64}(undef, 2, 3)
Expand Down

0 comments on commit 770b064

Please sign in to comment.