Skip to content

Commit

Permalink
Proper FunctionWrappers Support (#367)
Browse files Browse the repository at this point in the history
* Initial work on FunctionWrappers integration

* Bump patch version

* Some more work

* Make tests pass

* Add rule to avoid differentiating type comparison

* Add StackDict type

* Fix bug

* Fix bug

* Remove call to function which may not be possible to call

* Finish off FunctionWrappers

* TwicePrecision functionality

* Add integration test for TemporalGPs with StepRangeLen

* Update includes etc

* Fix method ambiguity

* Test + add rules for LogRange-related functionality

* Only do logrange stuff on 1.11

* Fix test error

* Bump patch
  • Loading branch information
willtebbutt authored Nov 17, 2024
1 parent f182cdf commit c4fbfc8
Show file tree
Hide file tree
Showing 14 changed files with 613 additions and 7 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.42"
version = "0.4.43"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down Expand Up @@ -48,6 +49,7 @@ DiffRules = "1"
DiffTests = "0.1"
DynamicPPL = "0.29, 0.30"
ExprTools = "0.1"
FunctionWrappers = "1.1.3"
Graphs = "1"
InteractiveUtils = "1"
JET = "0.9"
Expand Down
5 changes: 4 additions & 1 deletion src/Mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import ChainRulesCore

using Base:
IEEEFloat, unsafe_convert, unsafe_pointer_to_objref, pointer_from_objref, arrayref,
arrayset
arrayset, TwicePrecision, twiceprecision
using Base.Experimental: @opaque
using Base.Iterators: product
using Core:
Expand All @@ -29,6 +29,7 @@ using Core.Compiler: IRCode, NewInstruction
using Core.Intrinsics: pointerref, pointerset
using LinearAlgebra.BLAS: @blasfunc, BlasInt, trsm!
using LinearAlgebra.LAPACK: getrf!, getrs!, getri!, trtrs!, potrf!, potrs!
using FunctionWrappers: FunctionWrapper

# Needs to be defined before various other things.
function _foreigncall_ end
Expand Down Expand Up @@ -82,13 +83,15 @@ include(joinpath("rrules", "blas.jl"))
include(joinpath("rrules", "builtins.jl"))
include(joinpath("rrules", "fastmath.jl"))
include(joinpath("rrules", "foreigncall.jl"))
include(joinpath("rrules", "function_wrappers.jl"))
include(joinpath("rrules", "iddict.jl"))
include(joinpath("rrules", "lapack.jl"))
include(joinpath("rrules", "linear_algebra.jl"))
include(joinpath("rrules", "low_level_maths.jl"))
include(joinpath("rrules", "misc.jl"))
include(joinpath("rrules", "new.jl"))
include(joinpath("rrules", "tasks.jl"))
include(joinpath("rrules", "twice_precision.jl"))
@static if VERSION >= v"1.11-rc4"
include(joinpath("rrules", "memory.jl"))
else
Expand Down
2 changes: 1 addition & 1 deletion src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ zero_rdata(p::IEEEFloat) = zero(p)
R == NoRData && return :(NoRData())

# T ought to be a `Tangent`. If it's not, something has gone wrong.
!(T <: Tangent) && Expr(:call, error, "Unhandled type $T")
!(T <: Tangent) && return Expr(:call, error, "Unhandled type $T")
rdata_field_zeros_exprs = ntuple(fieldcount(P)) do n
R_field = rdata_field_type(P, n)
if R_field <: PossiblyUninitTangent
Expand Down
1 change: 1 addition & 0 deletions src/rrules/avoiding_non_differentiable_code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ end
@zero_adjoint MinimalCtx Tuple{Type{Float64}, Any, RoundingMode}
@zero_adjoint MinimalCtx Tuple{Type{Float32}, Any, RoundingMode}
@zero_adjoint MinimalCtx Tuple{Type{Float16}, Any, RoundingMode}
@zero_adjoint MinimalCtx Tuple{typeof(==), Type, Type}

function generate_hand_written_rrule!!_test_cases(
rng_ctor, ::Val{:avoiding_non_differentiable_code}
Expand Down
195 changes: 195 additions & 0 deletions src/rrules/function_wrappers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Type used to represent tangents of `FunctionWrapper`s. Also used to represent its fdata
# because `FunctionWrapper`s are mutable types.
mutable struct FunctionWrapperTangent{Tfwds_oc}
fwds_wrapper::Tfwds_oc
dobj_ref::Ref
end

function _construct_types(R, A)

# Convert signature into a tuple of types.
primal_arg_types = (A.parameters..., )

# Signature and OpaqueClosure type for reverse pass.
rvs_sig = Tuple{rdata_type(tangent_type(R))}
primal_rdata_sig = Tuple{map(rdata_type tangent_type, primal_arg_types)...}
pb_ret_type = Tuple{NoRData, primal_rdata_sig.parameters...}
rvs_oc_type = Core.OpaqueClosure{rvs_sig, pb_ret_type}

# Signature and OpaqueClosure type for forwards pass.
fwd_sig = Tuple{map(fcodual_type, primal_arg_types)...}
fwd_oc_type = Core.OpaqueClosure{fwd_sig, Tuple{fcodual_type(R), rvs_oc_type}}
return fwd_oc_type, rvs_oc_type, fwd_sig, rvs_sig
end

function tangent_type(::Type{FunctionWrapper{R, A}}) where {R, A<:Tuple}
return FunctionWrapperTangent{_construct_types(R, A)[1]}
end

import .TestUtils: has_equal_data_internal
function has_equal_data_internal(
p::P, q::P, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}
) where {P<:FunctionWrapper}
return has_equal_data_internal(p.obj, q.obj, equal_undefs, d)
end
function has_equal_data_internal(
t::T, s::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}
) where {T<:FunctionWrapperTangent}
return has_equal_data_internal(t.dobj_ref[], s.dobj_ref[], equal_undefs, d)
end



function _function_wrapper_tangent(R, obj::Tobj, A, obj_tangent) where {Tobj}

# Analyse types.
_, _, fwd_sig, rvs_sig = _construct_types(R, A)

# Construct reference to obj_tangent that we can read / write-to.
obj_tangent_ref = Ref{tangent_type(Tobj)}(obj_tangent)

# Contruct a rule for `obj`, applied to its declared argument types.
rule = build_rrule(Tuple{Tobj, A.parameters...})

# Construct stack which can hold pullbacks generated by `rule`. The forwards-pass will
# run `rule` and push the pullback to `pb_stack`. The reverse-pass will pop and run it.
pb_stack = Stack{pullback_type(typeof(rule), (Tobj, A.parameters...))}()

# Construct reverse-pass. Note: this closes over `pb_stack`.
run_rvs_pass = Base.Experimental.@opaque rvs_sig dy -> begin
obj_rdata, dx... = pop!(pb_stack)(dy)
obj_tangent_ref[] = increment_rdata!!(obj_tangent_ref[], obj_rdata)
return NoRData(), dx...
end

# Construct fowards-pass. Note: this closes over the reverse-pass and `pb_stack`.
run_fwds_pass = Base.Experimental.@opaque fwd_sig (x...) -> begin
y, pb = rule(CoDual(obj, fdata(obj_tangent_ref[])), x...)
push!(pb_stack, pb)
return y, run_rvs_pass
end

t = FunctionWrapperTangent(run_fwds_pass, obj_tangent_ref)
return t, obj_tangent_ref
end

function zero_tangent_internal(
p::FunctionWrapper{R, A}, stackdict::Union{Nothing, IdDict}
) where {R, A}

# If we've seen this primal before, then we must return that tangent.
haskey(stackdict, p) && return stackdict[p]::tangent_type(typeof(p))

# We have not seen this primal before, create it and log it.
obj_tangent = zero_tangent_internal(p.obj[], stackdict)
t, _ = _function_wrapper_tangent(R, p.obj[], A, obj_tangent)
stackdict === nothing || setindex!(stackdict, t, p)
return t
end

function randn_tangent_internal(
rng::AbstractRNG, p::FunctionWrapper{R, A}, stackdict::Union{Nothing, IdDict}
) where {R, A}

# If we've seen this primal before, then we must return that tangent.
haskey(stackdict, p) && return stackdict[p]::tangent_type(typeof(p))

# We have not seen this primal before, create it and log it.
obj_tangent = randn_tangent_internal(rng, p.obj[], stackdict)
t, _ = _function_wrapper_tangent(R, p.obj[], A, obj_tangent)
stackdict === nothing || setindex!(stackdict, t, p)
return t
end

function increment!!(t::T, s::T) where {T<:FunctionWrapperTangent}
t.dobj_ref[] = increment!!(t.dobj_ref[], s.dobj_ref[])
return t
end

function set_to_zero!!(t::FunctionWrapperTangent)
t.dobj_ref[] = set_to_zero!!(t.dobj_ref[])
return t
end

function _add_to_primal(p::FunctionWrapper, t::FunctionWrapperTangent, unsafe::Bool)
return typeof(p)(_add_to_primal(p.obj[], t.dobj_ref[], unsafe))
end

function _diff(p::P, q::P) where {R, A, P<:FunctionWrapper{R, A}}
return first(_function_wrapper_tangent(R, p.obj[], A, _diff(p.obj[], q.obj[])))
end

_dot(t::T, s::T) where {T<:FunctionWrapperTangent} = _dot(t.dobj_ref[], s.dobj_ref[])

function _scale(a::Float64, t::T) where {T<:FunctionWrapperTangent}
return T(t.fwds_wrapper, Ref(_scale(a, t.dobj_ref[])))
end

import .TestUtils: populate_address_map!, AddressMap
function populate_address_map!(m::AddressMap, p::FunctionWrapper, t::FunctionWrapperTangent)
k = pointer_from_objref(p)
v = pointer_from_objref(t)
haskey(m, k) && (@assert m[k] == v)
m[k] = v
return m
end

fdata_type(T::Type{<:FunctionWrapperTangent}) = T
rdata_type(::Type{FunctionWrapperTangent}) = NoRData
tangent_type(F::Type{<:FunctionWrapperTangent}, ::Type{NoRData}) = F
tangent(f::FunctionWrapperTangent, ::NoRData) = f

_verify_fdata_value(p::FunctionWrapper, t::FunctionWrapperTangent) = nothing

@is_primitive MinimalCtx Tuple{Type{<:FunctionWrapper}, Any}
function rrule!!(::CoDual{Type{FunctionWrapper{R, A}}}, obj::CoDual{P}) where {R, A, P}
t, obj_tangent_ref = _function_wrapper_tangent(R, obj.x, A, zero_tangent(obj.x, obj.dx))
function_wrapper_pb(::NoRData) = NoRData(), rdata(obj_tangent_ref[])
return CoDual(FunctionWrapper{R, A}(obj.x), t), function_wrapper_pb
end

@is_primitive MinimalCtx Tuple{<:FunctionWrapper, Vararg}
function rrule!!(f::CoDual{<:FunctionWrapper}, x::Vararg{CoDual})
y, pb = f.dx.fwds_wrapper(x...)
function_wrapper_eval_pb(dy) = pb(dy)
return y, function_wrapper_eval_pb
end

function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:function_wrappers})
test_cases = Any[
(false, :none, nothing, FunctionWrapper{Float64, Tuple{Float64}}, sin),
(false, :none, nothing, FunctionWrapper{Float64, Tuple{Float64}}(sin), 5.0),
]
memory = Any[]
return test_cases, memory
end

function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:function_wrappers})
test_cases = Any[
(
false, :none, nothing,
function(x, y)
p = FunctionWrapper{Float64, Tuple{Float64}}(x -> x * y)
out = 0.0
for _ in 1:1_000
out += p(x)
end
return out
end,
5.0, 4.0,
),
(
false, :none, nothing,
function(x::Vector{Float64}, y::Float64)
p = FunctionWrapper{Float64, Tuple{Float64}}(x -> x * y)
out = 0.0
for _x in x
out += p(_x)
end
return out
end,
randn(100), randn(),
),
]
return test_cases, Any[]
end
6 changes: 4 additions & 2 deletions src/rrules/iddict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ function TestUtils.populate_address_map!(m::TestUtils.AddressMap, p::IdDict, t::
foreach(n -> TestUtils.populate_address_map!(m, p[n], t[n]), keys(p))
return m
end
function TestUtils.has_equal_data(p::P, q::P; equal_undefs=true) where {P<:IdDict}
function TestUtils.has_equal_data_internal(
p::P, q::P, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}
) where {P<:IdDict}
ks = union(keys(p), keys(q))
ks != keys(p) && return false
return all([TestUtils.has_equal_data(p[k], q[k]; equal_undefs) for k in ks])
return all([TestUtils.has_equal_data_internal(p[k], q[k], equal_undefs, d) for k in ks])
end

fdata_type(::Type{T}) where {T<:IdDict} = T
Expand Down
Loading

2 comments on commit c4fbfc8

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/119629

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.43 -m "<description of version>" c4fbfc8b18d8c424d77be1f4f6db3ab04317ef9d
git push origin v0.4.43

Please sign in to comment.