Skip to content

Commit

Permalink
sort! rules
Browse files Browse the repository at this point in the history
  • Loading branch information
jgreener64 committed Aug 14, 2023
1 parent 14540ef commit 82e44ea
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,56 @@ end
function EnzymeRules.inactive_noinl(::typeof(Base.size), args...)
return nothing
end

function EnzymeRules.forward(
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
xs::Duplicated;
kwargs...
)
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
xs.dval .= xs.dval[inds]
if RT <: Const
return xs.val
elseif RT <: DuplicatedNoNeed
return xs.dval
else
return xs
end
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.ConfigWidth{1},
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
xs::Duplicated;
kwargs...
)
inds = sortperm(xs.val; kwargs...)
if EnzymeRules.needs_primal(config)
primal = xs.val[inds]
else
primal = nothing
end
if RT <: Const
shadow = nothing
else
shadow = xs.dval[inds]
end
return EnzymeRules.AugmentedReturn(primal, shadow, inds)
end

function EnzymeRules.reverse(
config::EnzymeRules.ConfigWidth{1},
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
tape,
xs::Duplicated;
kwargs...,
)
inds = tape
back_inds = sortperm(inds)
xs.dval .= xs.dval[back_inds]
return (nothing,)
end

0 comments on commit 82e44ea

Please sign in to comment.