Skip to content

Commit

Permalink
abs rrule
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Buercklin committed Jul 28, 2022
1 parent 18d5597 commit c88449b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/UnitfulChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ using Unitful: Quantity, Units, NoDims, FreeUnits
using ChainRulesCore: NoTangent, @scalar_rule
import ChainRulesCore: rrule, frule, ProjectTo

const REALCOMPLEX = Union{Real, Complex}

include("./rrules.jl")
include("./frules.jl")
include("./projection.jl")


include("./extras.jl") # Unitful-specific rules
include("./extras.jl") # extra Unitful-specific rules

include("./trig.jl") # sin, cos, tan, etc for degrees

include("./math.jl") # other math

end # module
8 changes: 8 additions & 0 deletions src/math.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
function rrule(::typeof(abs), x::Unitful.Quantity{T,D,U}) where {T<:REALCOMPLEX, D, U}
Ω = abs(x)
function abs_pullback(ΔΩ)
signx = isreal(x) ? sign(x) : x / ifelse(iszero(x), oneunit(Ω), Ω)
return (NoTangent(), signx * real(ΔΩ))
end
return Ω, abs_pullback
end
12 changes: 12 additions & 0 deletions test/math.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Unitful
using UnitfulChainRules

using ChainRulesCore

@testset "abs" begin
z = (1 + im)u"W"
Ω, pb = rrule(abs, z)

@test Ω sqrt(2)u"W"
@test last(pb(1.0)) (1 + im)/sqrt(2)
end

0 comments on commit c88449b

Please sign in to comment.