From 7d55c2e103db22960e2f65c95d3c59c29ec8a5d6 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Sun, 29 Oct 2023 23:27:08 +0000 Subject: [PATCH] Batched sort! rule --- src/internal_rules.jl | 20 ++++++++++++++++++++ test/internal_rules.jl | 9 ++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index b8304f9a58a..8585005c77b 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -433,6 +433,26 @@ function EnzymeRules.forward( end end +function EnzymeRules.forward( + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, + xs::BatchDuplicated{T, N}; + kwargs... + ) where {T, N} + inds = sortperm(xs.val; kwargs...) + xs.val .= xs.val[inds] + for i in 1:N + xs.dval[i] .= xs.dval[i][inds] + end + if RT <: Const + return xs.val + elseif RT <: BatchDuplicatedNoNeed + return xs.dval + else + return xs + end +end + function EnzymeRules.augmented_primal( config::EnzymeRules.ConfigWidth{1}, ::Const{typeof(sort!)}, diff --git a/test/internal_rules.jl b/test/internal_rules.jl index e3793c21b3b..6698dbf81b5 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -11,9 +11,11 @@ using Test return a[2] end - @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 + @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 + @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) @test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1 - @test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 + @test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 + @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0) @test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0 function f2(x) @@ -22,7 +24,8 @@ using Test return sum(a .* [1, 2, 3, 4, 5]) end - @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 + @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 + @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 end