Skip to content

Commit

Permalink
Fix exponentiation for NaNMath.pow (#717)
Browse files Browse the repository at this point in the history
* fix NaNMath exponentiation

* reuse code

* fix

* add tests

* Update src/dual.jl

Co-authored-by: David Widmann <[email protected]>

* import NaNMath

* oops, no begin

* Update test/GradientTest.jl

Co-authored-by: David Widmann <[email protected]>

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
jClugstor and devmotion authored Nov 8, 2024
1 parent 7e9d778 commit 8eaba05
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ end
# exponentiation #
#----------------#

for f in (:(Base.:^), :(NaNMath.pow))
for (f, log) in ((:(Base.:^), :(Base.log)), (:(NaNMath.pow), :(NaNMath.log)))
@eval begin
@define_binary_dual_op(
$f,
Expand All @@ -565,7 +565,7 @@ for f in (:(Base.:^), :(NaNMath.pow))
elseif iszero(vx) && vy > 0
logval = zero(vx)
else
logval = expv * log(vx)
logval = expv * ($log)(vx)
end
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
return Dual{Txy}(expv, new_partials)
Expand All @@ -583,7 +583,7 @@ for f in (:(Base.:^), :(NaNMath.pow))
begin
v = value(y)
expv = ($f)(x, v)
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x)
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*($log)(x)
return Dual{Ty}(expv, deriv * partials(y))
end
)
Expand Down
12 changes: 12 additions & 0 deletions test/DerivativeTest.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module DerivativeTest

import Calculus
import NaNMath

using Test
using Random
Expand Down Expand Up @@ -93,6 +94,17 @@ end
@test (x -> ForwardDiff.derivative(y -> x^y, 1.5))(0.0) === 0.0
end

@testset "exponentiation with NaNMath" begin
@test isnan(ForwardDiff.derivative(x -> NaNMath.pow(NaN, x), 1.0))
@test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x,NaN), 1.0))
@test !isnan(ForwardDiff.derivative(x -> NaNMath.pow(1.0, x),1.0))
@test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x,0.5), -1.0))

@test isnan(ForwardDiff.derivative(x -> x^NaN, 2.0))
@test ForwardDiff.derivative(x -> x^2.0,2.0) == 4.0
@test_throws DomainError ForwardDiff.derivative(x -> x^0.5, -1.0)
end

@testset "dimension error for derivative" begin
@test_throws DimensionMismatch ForwardDiff.derivative(sum, fill(2pi, 3))
end
Expand Down
11 changes: 11 additions & 0 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module GradientTest

import Calculus
import NaNMath

using Test
using LinearAlgebra
Expand Down Expand Up @@ -200,6 +201,16 @@ end
@test ForwardDiff.gradient(L -> logdet(L), Matrix(L)) [1.0 -1.3333333333333337; 0.0 1.666666666666667]
end

@testset "gradient for exponential with NaNMath" begin
@test isnan(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[1]), [NaN, 1.0])[1])
@test ForwardDiff.gradient(x -> NaNMath.pow(x[1], x[2]), [1.0, 1.0]) == [1.0, 0.0]
@test isnan(ForwardDiff.gradient((x) -> NaNMath.pow(x[1], x[2]), [-1.0, 0.5])[1])

@test isnan(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0])[1])
@test ForwardDiff.gradient(x -> x[1]^x[2], [1.0, 1.0]) == [1.0, 0.0]
@test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5])
end

@testset "branches in mul!" begin
a, b = rand(3,3), rand(3,3)

Expand Down

0 comments on commit 8eaba05

Please sign in to comment.