Skip to content

Commit

Permalink
Merge pull request #9 from SBuercklin/sam/extra-functions
Browse files Browse the repository at this point in the history
uconvert, ustrip, trig over degrees, abs rules
  • Loading branch information
SBuercklin authored Jul 28, 2022
2 parents 0cf3fbe + ce64dfb commit 4e54f5f
Show file tree
Hide file tree
Showing 10 changed files with 218 additions and 85 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "UnitfulChainRules"
uuid = "f31437dd-25a7-4345-875f-756556e6935d"
authors = ["Sam Buercklin <[email protected]>"]
version = "0.1.0"
version = "0.1.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -14,7 +14,8 @@ julia = "1"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Random"]
test = ["Test", "Random", "SafeTestsets"]
13 changes: 11 additions & 2 deletions src/UnitfulChainRules.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
module UnitfulChainRules

using Unitful
using Unitful: Quantity, Units
using ChainRulesCore: NoTangent
using Unitful: Quantity, Units, NoDims, FreeUnits
using ChainRulesCore: NoTangent, @scalar_rule
import ChainRulesCore: rrule, frule, ProjectTo

const REALCOMPLEX = Union{Real, Complex}

include("./rrules.jl")
include("./frules.jl")
include("./projection.jl")


include("./extras.jl") # extra Unitful-specific rules

include("./trig.jl") # sin, cos, tan, etc for degrees

include("./math.jl") # other math

end # module
25 changes: 25 additions & 0 deletions src/extras.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Identity operation for non-Quantities
rrule(::typeof(ustrip), x::Number) = x, (Δ) -> (NoTangent(), Δ * one(x))

# Divide by the stripped units to backprop
function rrule(::typeof(ustrip), x::Quantity{T,D,U}) where {T,D,U}
ustripped = ustrip(x)
project_x = ProjectTo(x)
invU = inv(U())

ustrip_pb(Δ) = (NoTangent(), project_x* invU))

return ustripped, ustrip_pb
end

function rrule(::typeof(uconvert), u::FreeUnits{N,D,A}, x::TX) where {N,D,A,TX}
x_convert = uconvert(u, x)
conversion = uconvert(u, oneunit(x)) / oneunit(x)
project_x = ProjectTo(x)

function uconvert_pb(Δ)
return (NoTangent(), NoTangent(), project_x* conversion))
end

return x_convert, uconvert_pb
end
8 changes: 8 additions & 0 deletions src/math.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
function rrule(::typeof(abs), x::Unitful.Quantity{T,D,U}) where {T<:REALCOMPLEX, D, U}
Ω = abs(x)
function abs_pullback(ΔΩ)
signx = isreal(x) ? sign(x) : x / ifelse(iszero(x), oneunit(Ω), Ω)
return (NoTangent(), signx * real(ΔΩ))
end
return Ω, abs_pullback
end
17 changes: 17 additions & 0 deletions src/trig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
###########################
#=
Trigonometric Rules for Degrees
Let dx be differential in radians/dimensionless, dx° be in degrees
df/dx° = df/dx * dx/dx° = df/dx * π/180°
=#
###########################
const DEGREE_QUANTITY = Quantity{<:Number,NoDims,typeof(u"°")}
const TO_RAD = π/180u"°"

@scalar_rule sin(x::DEGREE_QUANTITY) cos(x) * TO_RAD
@scalar_rule cos(x::DEGREE_QUANTITY) -sin(x) * TO_RAD
@scalar_rule tan(x::DEGREE_QUANTITY) (1 + Ω^2) * TO_RAD
@scalar_rule csc(x::DEGREE_QUANTITY) -Ω * cot(x) * TO_RAD
@scalar_rule sec(x::DEGREE_QUANTITY) Ω * tan(x) * TO_RAD
@scalar_rule cot(x::DEGREE_QUANTITY) -(1 + Ω^2) * TO_RAD
19 changes: 19 additions & 0 deletions test/extras.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Unitful
using UnitfulChainRules
using ChainRulesCore

@testset "ustrip" begin
x = 5.0u"m"
Ω, pb = rrule(ustrip, x)

@test Ω == 5.0
@test last(pb(2.0)) == 2.0/u"m"
end

@testset "uconvert" begin
x = 30.0u"°"
Ω, pb = rrule(uconvert, u"rad", x)

@test Ω /6)u"rad"
@test last(pb(1.0)) π*u"rad"/180u"°"
end
12 changes: 12 additions & 0 deletions test/math.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Unitful
using UnitfulChainRules

using ChainRulesCore

@testset "abs" begin
z = (1 + im)u"W"
Ω, pb = rrule(abs, z)

@test Ω sqrt(2)u"W"
@test last(pb(1.0)) (1 + im)/sqrt(2)
end
85 changes: 85 additions & 0 deletions test/rrules-frules-projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
using Unitful
using UnitfulChainRules

using ChainRulesCore: frule, rrule, ProjectTo, NoTangent

using Random
rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451)

@testset "ProjectTo" begin
real_test(proj, val) = proj(val) == real(val)
complex_test(proj, val) = proj(val) == val
ru = randn(rng)
ruval = ru*u"W"
p_ruval = ProjectTo(ruval)

cu = randn(rng, ComplexF64)
cuval = cu*u"kg"
p_cuval = ProjectTo(cuval)

p_real = ProjectTo(ru)
p_complex = ProjectTo(cu)

δr = randn(rng)
δrval = δr*u"m"

δc = randn(rng, ComplexF64)
δcval = δc*u"L"

# Test projection onto real unitful quantities
for δ in (δrval, δcval, ru, cu)
@test real_test(p_ruval, δ)
end

# Test projection onto complex unitful quantities
for δ in (δrval, δcval, ru, cu)
@test complex_test(p_cuval, δ)
end

# Projecting Unitful quantities onto real values
@test p_real(δrval) == δrval
@test p_real(δcval) == real(δcval)

# Projecting Unitful quantities onto complex values
@test p_complex(δrval) == δrval
@test p_complex(δcval) == δcval
end

@testset "rrules" begin
@testset "Quantity rrule" begin
UT = typeof(1.0*u"W")
x = randn(rng)
δx = randn(rng)
Ω, pb = rrule(UT, x)
@test Ω == x * u"W"
@test pb(δx) == (NoTangent(), δx * u"W")
end
@testset "* rrule" begin
x = randn(rng)*u"W"
y = u"m"
z = u"L"
Ω, pb = rrule(*, x, y, z)
@test Ω == x*y*z
δ = randn(rng)
@test pb(δ) == (NoTangent(), δ*y*z, NoTangent(), NoTangent())
end
end

@testset "frules" begin
@testset "Quantity frule" begin
UT = typeof(1.0*u"W")
x = randn(rng)
δx = randn(rng)
X, ∂X = frule((nothing, δx), UT, x)
@test X == x * u"W"
@test ∂X == δx * u"W"
end
@testset "* frule" begin
x = randn(rng)*u"W"
δx = randn(rng)*u"L"
y = u"m"
X, ∂X = frule((nothing, δx, nothing), *, x, y)
@test X == x*y
@test ∂X == δx*y
end
end
89 changes: 8 additions & 81 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,88 +1,15 @@
using Unitful
using UnitfulChainRules

using ChainRulesCore: frule, rrule, ProjectTo, NoTangent

using Random
using SafeTestsets

using Test

rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451)

@testset "ProjectTo" begin
real_test(proj, val) = proj(val) == real(val)
complex_test(proj, val) = proj(val) == val
ru = randn(rng)
ruval = ru*u"W"
p_ruval = ProjectTo(ruval)

cu = randn(rng, ComplexF64)
cuval = cu*u"kg"
p_cuval = ProjectTo(cuval)

p_real = ProjectTo(ru)
p_complex = ProjectTo(cu)

δr = randn(rng)
δrval = δr*u"m"

δc = randn(rng, ComplexF64)
δcval = δc*u"L"

# Test projection onto real unitful quantities
for δ in (δrval, δcval, ru, cu)
@test real_test(p_ruval, δ)
end

# Test projection onto complex unitful quantities
for δ in (δrval, δcval, ru, cu)
@test complex_test(p_cuval, δ)
end

# Projecting Unitful quantities onto real values
@test p_real(δrval) == δrval
@test p_real(δcval) == real(δcval)

# Projecting Unitful quantities onto complex values
@test p_complex(δrval) == δrval
@test p_complex(δcval) == δcval
@safetestset "rrules, frules, ProjectTo" begin
include("./rrules-frules-projection.jl")
end

@testset "rrules" begin
@testset "Quantity rrule" begin
UT = typeof(1.0*u"W")
x = randn(rng)
δx = randn(rng)
Ω, pb = rrule(UT, x)
@test Ω == x * u"W"
@test pb(δx) == (NoTangent(), δx * u"W")
end
@testset "* rrule" begin
x = randn(rng)*u"W"
y = u"m"
z = u"L"
Ω, pb = rrule(*, x, y, z)
@test Ω == x*y*z
δ = randn(rng)
@test pb(δ) == (NoTangent(), δ*y*z, NoTangent(), NoTangent())
end
@safetestset "Trig Operations" begin
include("./trig.jl")
end

@testset "frules" begin
@testset "Quantity frule" begin
UT = typeof(1.0*u"W")
x = randn(rng)
δx = randn(rng)
X, ∂X = frule((nothing, δx), UT, x)
@test X == x * u"W"
@test ∂X == δx * u"W"
end
@testset "* frule" begin
x = randn(rng)*u"W"
δx = randn(rng)*u"L"
y = u"m"
X, ∂X = frule((nothing, δx, nothing), *, x, y)
@test X == x*y
@test ∂X == δx*y
end
end
@safetestset "Extras" begin
include("./extras.jl")
end
30 changes: 30 additions & 0 deletions test/trig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using Unitful
using UnitfulChainRules

using ChainRulesCore

using Random
rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451)

dsin(Ω, x) = cos(x)
dcos(Ω, x) = -sin(x)
dtan(Ω, x) = 1 + Ω^2
dcsc(Ω, x) = - Ω * cot(x)
dsec(Ω, x) = Ω * tan(x)
dcot(Ω, x) = -(1 + Ω^2)

for (f, df) in (
(:sin, :dsin), (:cos,:dcos), (:tan,:dtan), (:csc,:dcsc), (:sec,:dsec), (:cot,:dcot)
)
eval(
quote
@testset "$($f)" begin
x = rand(rng)u"°"

Ω, pb = rrule($f, x)
@test Ω == $f(x)
@test last(pb(1.0)) $df(Ω, x) * π/180u"°"
end
end
)
end

2 comments on commit 4e54f5f

@SBuercklin
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/65190

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.1 -m "<description of version>" 4e54f5fea3c6bd1f5dda880a69c33af3038ee5b3
git push origin v0.1.1

Please sign in to comment.