diff --git a/src/UnitfulChainRules.jl b/src/UnitfulChainRules.jl index db52c4c..ebf442e 100644 --- a/src/UnitfulChainRules.jl +++ b/src/UnitfulChainRules.jl @@ -1,7 +1,7 @@ module UnitfulChainRules using Unitful -using Unitful: Quantity, Units, NoDims +using Unitful: Quantity, Units, NoDims, FreeUnits using ChainRulesCore: NoTangent, @scalar_rule import ChainRulesCore: rrule, frule, ProjectTo diff --git a/src/extras.jl b/src/extras.jl index f2b122d..e590576 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -11,3 +11,15 @@ function rrule(::typeof(ustrip), x::Quantity{T,D,U}) where {T,D,U} return ustripped, ustrip_pb end + +function rrule(::typeof(uconvert), u::FreeUnits{N,D,A}, x::TX) where {N,D,A,TX} + x_convert = uconvert(u, x) + conversion = uconvert(u, oneunit(x)) / oneunit(x) + project_x = ProjectTo(x) + + function uconvert_pb(Δ) + return (NoTangent(), NoTangent(), project_x(Δ * conversion)) + end + + return x_convert, uconvert_pb +end \ No newline at end of file diff --git a/test/extras.jl b/test/extras.jl index d8d3f06..64f4692 100644 --- a/test/extras.jl +++ b/test/extras.jl @@ -8,4 +8,12 @@ using ChainRulesCore @test Ω == 5.0 @test last(pb(2.0)) == 2.0/u"m" +end + +@testset "uconvert" begin + x = 30.0u"°" + Ω, pb = rrule(uconvert, u"rad", x) + + @test Ω ≈ (π/6)u"rad" + @test last(pb(1.0)) ≈ π*u"rad"/180u"°" end \ No newline at end of file