Skip to content

Commit

Permalink
updates for Enzyme changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jgreener64 authored and wsmoses committed Dec 7, 2024
1 parent d3dd701 commit 34e06cd
Showing 1 changed file with 14 additions and 25 deletions.
39 changes: 14 additions & 25 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -797,29 +797,21 @@ end
@test autodiff(Reverse, f26, Active, Active(2.0))[1][1] == 2
@test autodiff(Forward, f26, Duplicated(2.0, 1.0))[1] == 2

f27(x) = sum(diff([0.0 x; 1.0 2x]; dims=2))
f27(x) = repeat([x 3x], 3)[2, 2]
@test autodiff(Reverse, f27, Active, Active(2.0))[1][1] == 3
@test autodiff(Forward, f27, Duplicated(2.0, 1.0))[1] == 3

f28(x) = repeat([x 3x], 3)[2, 2]
@test autodiff(Reverse, f28, Active, Active(2.0))[1][1] == 3
@test autodiff(Forward, f28, Duplicated(2.0, 1.0))[1] == 3
f28(x) = x * sum(trues(4, 3))
@test autodiff(Reverse, f28, Active, Active(2.0))[1][1] == 12
@test autodiff(Forward, f28, Duplicated(2.0, 1.0))[1] == 12

f29(x) = rot180([x 2x; 3x 4x], 3)[1, 1]
@test autodiff(Reverse, f29, Active, Active(2.0))[1][1] == 4
@test autodiff(Forward, f29, Duplicated(2.0, 1.0))[1] == 4
f29(x) = sum(Set([1.0, x, 2x, x]))
@test autodiff(Reverse, f29, Active, Active(2.0))[1][1] == 3
@test autodiff(Forward, f29, Duplicated(2.0, 1.0))[1] == 3

f30(x) = x * sum(trues(4, 3))
@test autodiff(Reverse, f30, Active, Active(2.0))[1][1] == 12
@test autodiff(Forward, f30, Duplicated(2.0, 1.0))[1] == 12

f31(x) = sum(Set([1.0, x, 2x, x]))
@test autodiff(Reverse, f31, Active, Active(2.0))[1][1] == 3
@test autodiff(Forward, f31, Duplicated(2.0, 1.0))[1] == 3

f32(x) = reverse([x 2.0 3x])[1]
@test autodiff(Reverse, f32, Active, Active(2.0))[1][1] == 3
@test autodiff(Forward, f32, Duplicated(2.0, 1.0))[1] == 3
f30(x) = reverse([x 2.0 3x])[1]
@test autodiff(Reverse, f30, Active, Active(2.0))[1][1] == 3
@test autodiff(Forward, f30, Duplicated(2.0, 1.0))[1] == 3
end

function deadarg_pow(z::T, i) where {T<:Real}
Expand Down Expand Up @@ -885,7 +877,7 @@ end
@test autodiff(Forward, (x,y) -> autodiff(Forward, Const(tonest), Duplicated(x, 1.0), Const(y))[1], Const(1.0), Duplicated(2.0, 1.0))[1] 2.0

f_nest(x) = 2 * x^4
deriv(f, x) = first(first(autodiff_deferred(Reverse, f, Active(x))))
deriv(f, x) = first(first(autodiff(Reverse, f, Active(x))))
f′(x) = deriv(f_nest, x)
f′′(x) = deriv(f′, x)

Expand Down Expand Up @@ -3757,12 +3749,9 @@ end
@test autodiff(Reverse, f8, Active, Active(1.5))[1][1] == 0
@test autodiff(Forward, f8, Duplicated(1.5, 1.0))[1] == 0

# On Julia 1.6 the gradients are wrong (0.7 not 1.2) and on 1.7 it errors
@static if VERSION v"1.8-"
f9(x) = sum(quantile([1.0, x], [0.5, 0.7]))
@test autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 1.2
@test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2
end
f9(x) = sum(quantile([1.0, x], [0.5, 0.7]))
@test autodiff(Reverse, f9, Active, Active(2.0))[1][1] == 1.2
@test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2
end

@testset "hvcat_fill" begin
Expand Down

0 comments on commit 34e06cd

Please sign in to comment.