Skip to content

Commit

Permalink
Restructure + add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SBuercklin committed Jul 24, 2022
1 parent a42ab98 commit 80e2327
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 82 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ julia = "1"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Random"]
test = ["Test", "Random", "SafeTestsets"]
11 changes: 11 additions & 0 deletions test/extras.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using Unitful
using UnitfulChainRules
using ChainRulesCore

@testset "ustrip" begin
x = 5.0u"m"
Ω, pb = rrule(ustrip, x)

@test Ω == 5.0
@test last(pb(2.0)) == 2.0/u"m"
end
85 changes: 85 additions & 0 deletions test/rrules-frules-projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
using Unitful
using UnitfulChainRules

using ChainRulesCore: frule, rrule, ProjectTo, NoTangent

using Random
rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451)

@testset "ProjectTo" begin
real_test(proj, val) = proj(val) == real(val)
complex_test(proj, val) = proj(val) == val
ru = randn(rng)
ruval = ru*u"W"
p_ruval = ProjectTo(ruval)

cu = randn(rng, ComplexF64)
cuval = cu*u"kg"
p_cuval = ProjectTo(cuval)

p_real = ProjectTo(ru)
p_complex = ProjectTo(cu)

δr = randn(rng)
δrval = δr*u"m"

δc = randn(rng, ComplexF64)
δcval = δc*u"L"

# Test projection onto real unitful quantities
for δ in (δrval, δcval, ru, cu)
@test real_test(p_ruval, δ)
end

# Test projection onto complex unitful quantities
for δ in (δrval, δcval, ru, cu)
@test complex_test(p_cuval, δ)
end

# Projecting Unitful quantities onto real values
@test p_real(δrval) == δrval
@test p_real(δcval) == real(δcval)

# Projecting Unitful quantities onto complex values
@test p_complex(δrval) == δrval
@test p_complex(δcval) == δcval
end

@testset "rrules" begin
@testset "Quantity rrule" begin
UT = typeof(1.0*u"W")
x = randn(rng)
δx = randn(rng)
Ω, pb = rrule(UT, x)
@test Ω == x * u"W"
@test pb(δx) == (NoTangent(), δx * u"W")
end
@testset "* rrule" begin
x = randn(rng)*u"W"
y = u"m"
z = u"L"
Ω, pb = rrule(*, x, y, z)
@test Ω == x*y*z
δ = randn(rng)
@test pb(δ) == (NoTangent(), δ*y*z, NoTangent(), NoTangent())
end
end

@testset "frules" begin
@testset "Quantity frule" begin
UT = typeof(1.0*u"W")
x = randn(rng)
δx = randn(rng)
X, ∂X = frule((nothing, δx), UT, x)
@test X == x * u"W"
@test ∂X == δx * u"W"
end
@testset "* frule" begin
x = randn(rng)*u"W"
δx = randn(rng)*u"L"
y = u"m"
X, ∂X = frule((nothing, δx, nothing), *, x, y)
@test X == x*y
@test ∂X == δx*y
end
end
89 changes: 8 additions & 81 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,88 +1,15 @@
using Unitful
using UnitfulChainRules

using ChainRulesCore: frule, rrule, ProjectTo, NoTangent

using Random
using SafeTestsets

using Test

rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451)

@testset "ProjectTo" begin
real_test(proj, val) = proj(val) == real(val)
complex_test(proj, val) = proj(val) == val
ru = randn(rng)
ruval = ru*u"W"
p_ruval = ProjectTo(ruval)

cu = randn(rng, ComplexF64)
cuval = cu*u"kg"
p_cuval = ProjectTo(cuval)

p_real = ProjectTo(ru)
p_complex = ProjectTo(cu)

δr = randn(rng)
δrval = δr*u"m"

δc = randn(rng, ComplexF64)
δcval = δc*u"L"

# Test projection onto real unitful quantities
for δ in (δrval, δcval, ru, cu)
@test real_test(p_ruval, δ)
end

# Test projection onto complex unitful quantities
for δ in (δrval, δcval, ru, cu)
@test complex_test(p_cuval, δ)
end

# Projecting Unitful quantities onto real values
@test p_real(δrval) == δrval
@test p_real(δcval) == real(δcval)

# Projecting Unitful quantities onto complex values
@test p_complex(δrval) == δrval
@test p_complex(δcval) == δcval
@safetestset "rrules, frules, ProjectTo" begin
include("./rrules-frules-projection.jl")
end

@testset "rrules" begin
@testset "Quantity rrule" begin
UT = typeof(1.0*u"W")
x = randn(rng)
δx = randn(rng)
Ω, pb = rrule(UT, x)
@test Ω == x * u"W"
@test pb(δx) == (NoTangent(), δx * u"W")
end
@testset "* rrule" begin
x = randn(rng)*u"W"
y = u"m"
z = u"L"
Ω, pb = rrule(*, x, y, z)
@test Ω == x*y*z
δ = randn(rng)
@test pb(δ) == (NoTangent(), δ*y*z, NoTangent(), NoTangent())
end
@safetestset "Trig Operations" begin
include("./trig.jl")
end

@testset "frules" begin
@testset "Quantity frule" begin
UT = typeof(1.0*u"W")
x = randn(rng)
δx = randn(rng)
X, ∂X = frule((nothing, δx), UT, x)
@test X == x * u"W"
@test ∂X == δx * u"W"
end
@testset "* frule" begin
x = randn(rng)*u"W"
δx = randn(rng)*u"L"
y = u"m"
X, ∂X = frule((nothing, δx, nothing), *, x, y)
@test X == x*y
@test ∂X == δx*y
end
end
@safetestset "Extras" begin
include("./extras.jl")
end
30 changes: 30 additions & 0 deletions test/trig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using Unitful
using UnitfulChainRules

using ChainRulesCore

using Random
rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451)

dsin(Ω, x) = cos(x)
dcos(Ω, x) = -sin(x)
dtan(Ω, x) = 1 + Ω^2
dcsc(Ω, x) = - Ω * cot(x)
dsec(Ω, x) = Ω * tan(x)
dcot(Ω, x) = -(1 + Ω^2)

for (f, df) in (
(:sin, :dsin), (:cos,:dcos), (:tan,:dtan), (:csc,:dcsc), (:sec,:dsec), (:cot,:dcot)
)
eval(
quote
@testset "$($f)" begin
x = rand(rng)u"°"

Ω, pb = rrule($f, x)
@test Ω == $f(x)
@test last(pb(1.0)) $df(Ω, x) * π/180u"°"
end
end
)
end

0 comments on commit 80e2327

Please sign in to comment.