diff --git a/Project.toml b/Project.toml index 69dab36..c704f40 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "UnitfulChainRules" uuid = "f31437dd-25a7-4345-875f-756556e6935d" authors = ["Sam Buercklin "] -version = "0.1.0" +version = "0.1.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -14,7 +14,8 @@ julia = "1" [extras] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Random"] +test = ["Test", "Random", "SafeTestsets"] diff --git a/src/UnitfulChainRules.jl b/src/UnitfulChainRules.jl index a82ffd8..11a7430 100644 --- a/src/UnitfulChainRules.jl +++ b/src/UnitfulChainRules.jl @@ -1,12 +1,21 @@ module UnitfulChainRules using Unitful -using Unitful: Quantity, Units -using ChainRulesCore: NoTangent +using Unitful: Quantity, Units, NoDims, FreeUnits +using ChainRulesCore: NoTangent, @scalar_rule import ChainRulesCore: rrule, frule, ProjectTo +const REALCOMPLEX = Union{Real, Complex} + include("./rrules.jl") include("./frules.jl") include("./projection.jl") + +include("./extras.jl") # extra Unitful-specific rules + +include("./trig.jl") # sin, cos, tan, etc for degrees + +include("./math.jl") # other math + end # module diff --git a/src/extras.jl b/src/extras.jl new file mode 100644 index 0000000..e590576 --- /dev/null +++ b/src/extras.jl @@ -0,0 +1,25 @@ +# Identity operation for non-Quantities +rrule(::typeof(ustrip), x::Number) = x, (Δ) -> (NoTangent(), Δ * one(x)) + +# Divide by the stripped units to backprop +function rrule(::typeof(ustrip), x::Quantity{T,D,U}) where {T,D,U} + ustripped = ustrip(x) + project_x = ProjectTo(x) + invU = inv(U()) + + ustrip_pb(Δ) = (NoTangent(), project_x(Δ * invU)) + + return ustripped, ustrip_pb +end + +function rrule(::typeof(uconvert), u::FreeUnits{N,D,A}, x::TX) where {N,D,A,TX} + x_convert = uconvert(u, x) + conversion = uconvert(u, oneunit(x)) / oneunit(x) + project_x = ProjectTo(x) + + function uconvert_pb(Δ) + return (NoTangent(), NoTangent(), project_x(Δ * conversion)) + end + + return x_convert, uconvert_pb +end \ No newline at end of file diff --git a/src/math.jl b/src/math.jl new file mode 100644 index 0000000..e419b60 --- /dev/null +++ b/src/math.jl @@ -0,0 +1,8 @@ +function rrule(::typeof(abs), x::Unitful.Quantity{T,D,U}) where {T<:REALCOMPLEX, D, U} + Ω = abs(x) + function abs_pullback(ΔΩ) + signx = isreal(x) ? sign(x) : x / ifelse(iszero(x), oneunit(Ω), Ω) + return (NoTangent(), signx * real(ΔΩ)) + end + return Ω, abs_pullback +end \ No newline at end of file diff --git a/src/trig.jl b/src/trig.jl new file mode 100644 index 0000000..f7088e3 --- /dev/null +++ b/src/trig.jl @@ -0,0 +1,17 @@ +########################### +#= + Trigonometric Rules for Degrees + + Let dx be differential in radians/dimensionless, dx° be in degrees + df/dx° = df/dx * dx/dx° = df/dx * π/180° +=# +########################### +const DEGREE_QUANTITY = Quantity{<:Number,NoDims,typeof(u"°")} +const TO_RAD = π/180u"°" + +@scalar_rule sin(x::DEGREE_QUANTITY) cos(x) * TO_RAD +@scalar_rule cos(x::DEGREE_QUANTITY) -sin(x) * TO_RAD +@scalar_rule tan(x::DEGREE_QUANTITY) (1 + Ω^2) * TO_RAD +@scalar_rule csc(x::DEGREE_QUANTITY) -Ω * cot(x) * TO_RAD +@scalar_rule sec(x::DEGREE_QUANTITY) Ω * tan(x) * TO_RAD +@scalar_rule cot(x::DEGREE_QUANTITY) -(1 + Ω^2) * TO_RAD diff --git a/test/extras.jl b/test/extras.jl new file mode 100644 index 0000000..64f4692 --- /dev/null +++ b/test/extras.jl @@ -0,0 +1,19 @@ +using Unitful +using UnitfulChainRules +using ChainRulesCore + +@testset "ustrip" begin + x = 5.0u"m" + Ω, pb = rrule(ustrip, x) + + @test Ω == 5.0 + @test last(pb(2.0)) == 2.0/u"m" +end + +@testset "uconvert" begin + x = 30.0u"°" + Ω, pb = rrule(uconvert, u"rad", x) + + @test Ω ≈ (π/6)u"rad" + @test last(pb(1.0)) ≈ π*u"rad"/180u"°" +end \ No newline at end of file diff --git a/test/math.jl b/test/math.jl new file mode 100644 index 0000000..9d44998 --- /dev/null +++ b/test/math.jl @@ -0,0 +1,12 @@ +using Unitful +using UnitfulChainRules + +using ChainRulesCore + +@testset "abs" begin + z = (1 + im)u"W" + Ω, pb = rrule(abs, z) + + @test Ω ≈ sqrt(2)u"W" + @test last(pb(1.0)) ≈ (1 + im)/sqrt(2) +end \ No newline at end of file diff --git a/test/rrules-frules-projection.jl b/test/rrules-frules-projection.jl new file mode 100644 index 0000000..0f0c33a --- /dev/null +++ b/test/rrules-frules-projection.jl @@ -0,0 +1,85 @@ +using Unitful +using UnitfulChainRules + +using ChainRulesCore: frule, rrule, ProjectTo, NoTangent + +using Random +rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451) + +@testset "ProjectTo" begin + real_test(proj, val) = proj(val) == real(val) + complex_test(proj, val) = proj(val) == val + ru = randn(rng) + ruval = ru*u"W" + p_ruval = ProjectTo(ruval) + + cu = randn(rng, ComplexF64) + cuval = cu*u"kg" + p_cuval = ProjectTo(cuval) + + p_real = ProjectTo(ru) + p_complex = ProjectTo(cu) + + δr = randn(rng) + δrval = δr*u"m" + + δc = randn(rng, ComplexF64) + δcval = δc*u"L" + + # Test projection onto real unitful quantities + for δ in (δrval, δcval, ru, cu) + @test real_test(p_ruval, δ) + end + + # Test projection onto complex unitful quantities + for δ in (δrval, δcval, ru, cu) + @test complex_test(p_cuval, δ) + end + + # Projecting Unitful quantities onto real values + @test p_real(δrval) == δrval + @test p_real(δcval) == real(δcval) + + # Projecting Unitful quantities onto complex values + @test p_complex(δrval) == δrval + @test p_complex(δcval) == δcval +end + +@testset "rrules" begin + @testset "Quantity rrule" begin + UT = typeof(1.0*u"W") + x = randn(rng) + δx = randn(rng) + Ω, pb = rrule(UT, x) + @test Ω == x * u"W" + @test pb(δx) == (NoTangent(), δx * u"W") + end + @testset "* rrule" begin + x = randn(rng)*u"W" + y = u"m" + z = u"L" + Ω, pb = rrule(*, x, y, z) + @test Ω == x*y*z + δ = randn(rng) + @test pb(δ) == (NoTangent(), δ*y*z, NoTangent(), NoTangent()) + end +end + +@testset "frules" begin + @testset "Quantity frule" begin + UT = typeof(1.0*u"W") + x = randn(rng) + δx = randn(rng) + X, ∂X = frule((nothing, δx), UT, x) + @test X == x * u"W" + @test ∂X == δx * u"W" + end + @testset "* frule" begin + x = randn(rng)*u"W" + δx = randn(rng)*u"L" + y = u"m" + X, ∂X = frule((nothing, δx, nothing), *, x, y) + @test X == x*y + @test ∂X == δx*y + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 88d16f9..39ec9aa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,88 +1,15 @@ -using Unitful -using UnitfulChainRules - -using ChainRulesCore: frule, rrule, ProjectTo, NoTangent - -using Random +using SafeTestsets using Test -rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451) - -@testset "ProjectTo" begin - real_test(proj, val) = proj(val) == real(val) - complex_test(proj, val) = proj(val) == val - ru = randn(rng) - ruval = ru*u"W" - p_ruval = ProjectTo(ruval) - - cu = randn(rng, ComplexF64) - cuval = cu*u"kg" - p_cuval = ProjectTo(cuval) - - p_real = ProjectTo(ru) - p_complex = ProjectTo(cu) - - δr = randn(rng) - δrval = δr*u"m" - - δc = randn(rng, ComplexF64) - δcval = δc*u"L" - - # Test projection onto real unitful quantities - for δ in (δrval, δcval, ru, cu) - @test real_test(p_ruval, δ) - end - - # Test projection onto complex unitful quantities - for δ in (δrval, δcval, ru, cu) - @test complex_test(p_cuval, δ) - end - - # Projecting Unitful quantities onto real values - @test p_real(δrval) == δrval - @test p_real(δcval) == real(δcval) - - # Projecting Unitful quantities onto complex values - @test p_complex(δrval) == δrval - @test p_complex(δcval) == δcval +@safetestset "rrules, frules, ProjectTo" begin + include("./rrules-frules-projection.jl") end -@testset "rrules" begin - @testset "Quantity rrule" begin - UT = typeof(1.0*u"W") - x = randn(rng) - δx = randn(rng) - Ω, pb = rrule(UT, x) - @test Ω == x * u"W" - @test pb(δx) == (NoTangent(), δx * u"W") - end - @testset "* rrule" begin - x = randn(rng)*u"W" - y = u"m" - z = u"L" - Ω, pb = rrule(*, x, y, z) - @test Ω == x*y*z - δ = randn(rng) - @test pb(δ) == (NoTangent(), δ*y*z, NoTangent(), NoTangent()) - end +@safetestset "Trig Operations" begin + include("./trig.jl") end -@testset "frules" begin - @testset "Quantity frule" begin - UT = typeof(1.0*u"W") - x = randn(rng) - δx = randn(rng) - X, ∂X = frule((nothing, δx), UT, x) - @test X == x * u"W" - @test ∂X == δx * u"W" - end - @testset "* frule" begin - x = randn(rng)*u"W" - δx = randn(rng)*u"L" - y = u"m" - X, ∂X = frule((nothing, δx, nothing), *, x, y) - @test X == x*y - @test ∂X == δx*y - end -end \ No newline at end of file +@safetestset "Extras" begin + include("./extras.jl") +end diff --git a/test/trig.jl b/test/trig.jl new file mode 100644 index 0000000..1b5f19a --- /dev/null +++ b/test/trig.jl @@ -0,0 +1,30 @@ +using Unitful +using UnitfulChainRules + +using ChainRulesCore + +using Random +rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451) + +dsin(Ω, x) = cos(x) +dcos(Ω, x) = -sin(x) +dtan(Ω, x) = 1 + Ω^2 +dcsc(Ω, x) = - Ω * cot(x) +dsec(Ω, x) = Ω * tan(x) +dcot(Ω, x) = -(1 + Ω^2) + +for (f, df) in ( + (:sin, :dsin), (:cos,:dcos), (:tan,:dtan), (:csc,:dcsc), (:sec,:dsec), (:cot,:dcot) + ) + eval( + quote + @testset "$($f)" begin + x = rand(rng)u"°" + + Ω, pb = rrule($f, x) + @test Ω == $f(x) + @test last(pb(1.0)) ≈ $df(Ω, x) * π/180u"°" + end + end + ) +end \ No newline at end of file