Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sort! rules #1000

Merged
merged 8 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,78 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Ty
return (nothing, nothing)
end
end

function EnzymeRules.forward(
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
xs::Duplicated;
kwargs...
)
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like sortperm itself just makes a list of 1:N, then sorts that in place using the original data as the order: https://github.com/JuliaLang/julia/blob/750df9fb5bede16f321f5d5405943d12aec7b83e/base/sort.jl#L1756

Could we change this to first sort the derivative array in place using the primal array as the order, then do the actual sort on the primal? That way we don't have that temporary

xs.dval .= xs.dval[inds]
if RT <: Const
return xs.val
jgreener64 marked this conversation as resolved.
Show resolved Hide resolved
elseif RT <: DuplicatedNoNeed
return xs.dval
else
return xs
end
end

function EnzymeRules.forward(
::Const{typeof(sort!)},
RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}},
xs::BatchDuplicated{T, N};
kwargs...
) where {T, N}
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
for i in 1:N
xs.dval[i] .= xs.dval[i][inds]
end
if RT <: Const
return xs.val
elseif RT <: BatchDuplicatedNoNeed
return xs.dval
else
return xs
end
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.ConfigWidth{1},
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
xs::Duplicated;
kwargs...
)
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
xs.dval .= xs.dval[inds]
if EnzymeRules.needs_primal(config)
jgreener64 marked this conversation as resolved.
Show resolved Hide resolved
primal = xs.val
else
primal = nothing
end
if RT <: Const
shadow = nothing
else
shadow = xs.dval
end
return EnzymeRules.AugmentedReturn(primal, shadow, inds)
end

function EnzymeRules.reverse(
config::EnzymeRules.ConfigWidth{1},
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
tape,
xs::Duplicated;
kwargs...,
)
inds = tape
back_inds = sortperm(inds)
xs.dval .= xs.dval[back_inds]
return (nothing,)
end
86 changes: 86 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
module InternalRules

using Enzyme
using Enzyme.EnzymeRules
using Test

@testset "Internal rules" begin
function f1(x)
a = [1.0, 3.0, x]
sort!(a)
return a[2]
end

@test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1
@test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0)
@test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1
@test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0
@test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0)
@test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0

function f2(x)
a = [1.0, -3.0, -x, -2x, x]
sort!(a; rev=true, lt=(x, y) -> abs(x) < abs(y) || (abs(x) == abs(y) && x < y))
return sum(a .* [1, 2, 3, 4, 5])
end

@test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3
@test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0)
@test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3
end

@testset "Linear Solve" begin
A = Float64[2 3; 5 7]
dA = zero(A)
b = Float64[11, 13]
db = zero(b)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)})

tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape)

z = transpose(A) \ dy

y = A \ b
@test dA ≈ (-z * transpose(y))
@test db ≈ z

db = zero(b)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)})

tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Const(A), Duplicated(b, db), tape)

z = transpose(A) \ dy

y = A \ b
@test db ≈ z

dA = zero(A)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)})

tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Duplicated(A, dA), Const(b), tape)

z = transpose(A) \ dy

y = A \ b
@test dA ≈ (-z * transpose(y))
end

end # InternalRules
55 changes: 1 addition & 54 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ include("typetree.jl")
include("rrules.jl")
include("kwrules.jl")
include("kwrrules.jl")
include("internal_rules.jl")
jgreener64 marked this conversation as resolved.
Show resolved Hide resolved
@static if VERSION ≥ v"1.9-"
# XXX invalidation does not work on Julia 1.8
include("ruleinvalidation.jl")
Expand Down Expand Up @@ -2615,60 +2616,6 @@ end
end
end

@testset "Linear Solve" begin
A = Float64[2 3; 5 7]
dA = zero(A)
b = Float64[11, 13]
db = zero(b)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)})

tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape)

z = transpose(A) \ dy

y = A \ b
@test dA ≈ (-z * transpose(y))
@test db ≈ z

db = zero(b)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)})

tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Const(A), Duplicated(b, db), tape)

z = transpose(A) \ dy

y = A \ b
@test db ≈ z

dA = zero(A)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)})

tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Duplicated(A, dA), Const(b), tape)

z = transpose(A) \ dy

y = A \ b
@test dA ≈ (-z * transpose(y))
end

@static if VERSION >= v"1.7-"
@testset "hvcat_fill" begin
ar = Matrix{Float64}(undef, 2, 3)
Expand Down
Loading