Skip to content

Commit

Permalink
uconvert rrule
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Buercklin committed Jul 28, 2022
1 parent 80e2327 commit 18d5597
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/UnitfulChainRules.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module UnitfulChainRules

using Unitful
using Unitful: Quantity, Units, NoDims
using Unitful: Quantity, Units, NoDims, FreeUnits
using ChainRulesCore: NoTangent, @scalar_rule
import ChainRulesCore: rrule, frule, ProjectTo

Expand Down
12 changes: 12 additions & 0 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,15 @@ function rrule(::typeof(ustrip), x::Quantity{T,D,U}) where {T,D,U}

return ustripped, ustrip_pb
end

function rrule(::typeof(uconvert), u::FreeUnits{N,D,A}, x::TX) where {N,D,A,TX}
x_convert = uconvert(u, x)
conversion = uconvert(u, oneunit(x)) / oneunit(x)
project_x = ProjectTo(x)

function uconvert_pb(Δ)
return (NoTangent(), NoTangent(), project_x* conversion))
end

return x_convert, uconvert_pb
end
8 changes: 8 additions & 0 deletions test/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,12 @@ using ChainRulesCore

@test Ω == 5.0
@test last(pb(2.0)) == 2.0/u"m"
end

@testset "uconvert" begin
x = 30.0u"°"
Ω, pb = rrule(uconvert, u"rad", x)

@test Ω /6)u"rad"
@test last(pb(1.0)) π*u"rad"/180u"°"
end

0 comments on commit 18d5597

Please sign in to comment.