diff --git a/Project.toml b/Project.toml
index 94a7c775a..7b6ad3b0d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Mooncake"
 uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 authors = ["Will Tebbutt, Hong Ge, and contributors"]
-version = "0.4.42"
+version = "0.4.43"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
 DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
 ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
+FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -48,6 +49,7 @@ DiffRules = "1"
 DiffTests = "0.1"
 DynamicPPL = "0.29, 0.30"
 ExprTools = "0.1"
+FunctionWrappers = "1.1.3"
 Graphs = "1"
 InteractiveUtils = "1"
 JET = "0.9"
diff --git a/src/Mooncake.jl b/src/Mooncake.jl
index 0b5a222ac..6cd360006 100644
--- a/src/Mooncake.jl
+++ b/src/Mooncake.jl
@@ -19,7 +19,7 @@ import ChainRulesCore
 
 using Base:
     IEEEFloat, unsafe_convert, unsafe_pointer_to_objref, pointer_from_objref, arrayref,
-    arrayset
+    arrayset, TwicePrecision, twiceprecision
 using Base.Experimental: @opaque
 using Base.Iterators: product
 using Core:
@@ -29,6 +29,7 @@ using Core.Compiler: IRCode, NewInstruction
 using Core.Intrinsics: pointerref, pointerset
 using LinearAlgebra.BLAS: @blasfunc, BlasInt, trsm!
 using LinearAlgebra.LAPACK: getrf!, getrs!, getri!, trtrs!, potrf!, potrs!
+using FunctionWrappers: FunctionWrapper
 
 # Needs to be defined before various other things.
 function _foreigncall_ end
@@ -82,6 +83,7 @@ include(joinpath("rrules", "blas.jl"))
 include(joinpath("rrules", "builtins.jl"))
 include(joinpath("rrules", "fastmath.jl"))
 include(joinpath("rrules", "foreigncall.jl"))
+include(joinpath("rrules", "function_wrappers.jl"))
 include(joinpath("rrules", "iddict.jl"))
 include(joinpath("rrules", "lapack.jl"))
 include(joinpath("rrules", "linear_algebra.jl"))
@@ -89,6 +91,7 @@ include(joinpath("rrules", "low_level_maths.jl"))
 include(joinpath("rrules", "misc.jl"))
 include(joinpath("rrules", "new.jl"))
 include(joinpath("rrules", "tasks.jl"))
+include(joinpath("rrules", "twice_precision.jl"))
 @static if VERSION >= v"1.11-rc4"
     include(joinpath("rrules", "memory.jl"))
 else
diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl
index 8fbab992e..428f979f4 100644
--- a/src/fwds_rvs_data.jl
+++ b/src/fwds_rvs_data.jl
@@ -533,7 +533,7 @@ zero_rdata(p::IEEEFloat) = zero(p)
     R == NoRData && return :(NoRData())
 
     # T ought to be a `Tangent`. If it's not, something has gone wrong.
-    !(T <: Tangent) && Expr(:call, error, "Unhandled type $T")
+    !(T <: Tangent) && return Expr(:call, error, "Unhandled type $T")
     rdata_field_zeros_exprs = ntuple(fieldcount(P)) do n
         R_field = rdata_field_type(P, n)
         if R_field <: PossiblyUninitTangent
diff --git a/src/rrules/avoiding_non_differentiable_code.jl b/src/rrules/avoiding_non_differentiable_code.jl
index df4fe7b8f..519979390 100644
--- a/src/rrules/avoiding_non_differentiable_code.jl
+++ b/src/rrules/avoiding_non_differentiable_code.jl
@@ -12,6 +12,7 @@ end
 @zero_adjoint MinimalCtx Tuple{Type{Float64}, Any, RoundingMode}
 @zero_adjoint MinimalCtx Tuple{Type{Float32}, Any, RoundingMode}
 @zero_adjoint MinimalCtx Tuple{Type{Float16}, Any, RoundingMode}
+@zero_adjoint MinimalCtx Tuple{typeof(==), Type, Type}
 
 function generate_hand_written_rrule!!_test_cases(
     rng_ctor, ::Val{:avoiding_non_differentiable_code}
diff --git a/src/rrules/function_wrappers.jl b/src/rrules/function_wrappers.jl
new file mode 100644
index 000000000..09ade8395
--- /dev/null
+++ b/src/rrules/function_wrappers.jl
@@ -0,0 +1,195 @@
+# Type used to represent tangents of `FunctionWrapper`s. Also used to represent their
+# fdata, because `FunctionWrapper`s are mutable types.
+mutable struct FunctionWrapperTangent{Tfwds_oc}
+    fwds_wrapper::Tfwds_oc
+    dobj_ref::Ref
+end
+
+function _construct_types(R, A)
+
+    # Convert signature into a tuple of types.
+    primal_arg_types = (A.parameters..., )
+
+    # Signature and OpaqueClosure type for reverse pass.
+    rvs_sig = Tuple{rdata_type(tangent_type(R))}
+    primal_rdata_sig = Tuple{map(rdata_type ∘ tangent_type, primal_arg_types)...}
+    pb_ret_type = Tuple{NoRData, primal_rdata_sig.parameters...}
+    rvs_oc_type = Core.OpaqueClosure{rvs_sig, pb_ret_type}
+
+    # Signature and OpaqueClosure type for forwards pass.
+    fwd_sig = Tuple{map(fcodual_type, primal_arg_types)...}
+    fwd_oc_type = Core.OpaqueClosure{fwd_sig, Tuple{fcodual_type(R), rvs_oc_type}}
+    return fwd_oc_type, rvs_oc_type, fwd_sig, rvs_sig
+end
+
+function tangent_type(::Type{FunctionWrapper{R, A}}) where {R, A<:Tuple}
+    return FunctionWrapperTangent{_construct_types(R, A)[1]}
+end
+
+import .TestUtils: has_equal_data_internal
+function has_equal_data_internal(
+    p::P, q::P, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}
+) where {P<:FunctionWrapper}
+    return has_equal_data_internal(p.obj, q.obj, equal_undefs, d)
+end
+function has_equal_data_internal(
+    t::T, s::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}
+) where {T<:FunctionWrapperTangent}
+    return has_equal_data_internal(t.dobj_ref[], s.dobj_ref[], equal_undefs, d)
+end
+
+
+function _function_wrapper_tangent(R, obj::Tobj, A, obj_tangent) where {Tobj}
+
+    # Analyse types.
+    _, _, fwd_sig, rvs_sig = _construct_types(R, A)
+
+    # Construct a reference to obj_tangent that we can read from / write to.
+    obj_tangent_ref = Ref{tangent_type(Tobj)}(obj_tangent)
+
+    # Construct a rule for `obj`, applied to its declared argument types.
+    rule = build_rrule(Tuple{Tobj, A.parameters...})
+
+    # Construct stack which can hold pullbacks generated by `rule`. The forwards-pass will
+    # run `rule` and push the pullback to `pb_stack`. The reverse-pass will pop and run it.
+    pb_stack = Stack{pullback_type(typeof(rule), (Tobj, A.parameters...))}()
+
+    # Construct reverse-pass. Note: this closes over `pb_stack`.
+    run_rvs_pass = Base.Experimental.@opaque rvs_sig dy -> begin
+        obj_rdata, dx... = pop!(pb_stack)(dy)
+        obj_tangent_ref[] = increment_rdata!!(obj_tangent_ref[], obj_rdata)
+        return NoRData(), dx...
+    end
+
+    # Construct forwards-pass. Note: this closes over the reverse-pass and `pb_stack`.
+    run_fwds_pass = Base.Experimental.@opaque fwd_sig (x...) -> begin
+        y, pb = rule(CoDual(obj, fdata(obj_tangent_ref[])), x...)
+        push!(pb_stack, pb)
+        return y, run_rvs_pass
+    end
+
+    t = FunctionWrapperTangent(run_fwds_pass, obj_tangent_ref)
+    return t, obj_tangent_ref
+end
+
+function zero_tangent_internal(
+    p::FunctionWrapper{R, A}, stackdict::Union{Nothing, IdDict}
+) where {R, A}
+
+    # If we've seen this primal before, then we must return that tangent.
+    haskey(stackdict, p) && return stackdict[p]::tangent_type(typeof(p))
+
+    # We have not seen this primal before, so create a tangent and log it.
+    obj_tangent = zero_tangent_internal(p.obj[], stackdict)
+    t, _ = _function_wrapper_tangent(R, p.obj[], A, obj_tangent)
+    stackdict === nothing || setindex!(stackdict, t, p)
+    return t
+end
+
+function randn_tangent_internal(
+    rng::AbstractRNG, p::FunctionWrapper{R, A}, stackdict::Union{Nothing, IdDict}
+) where {R, A}
+
+    # If we've seen this primal before, then we must return that tangent.
+    haskey(stackdict, p) && return stackdict[p]::tangent_type(typeof(p))
+
+    # We have not seen this primal before, so create a tangent and log it.
+    obj_tangent = randn_tangent_internal(rng, p.obj[], stackdict)
+    t, _ = _function_wrapper_tangent(R, p.obj[], A, obj_tangent)
+    stackdict === nothing || setindex!(stackdict, t, p)
+    return t
+end
+
+function increment!!(t::T, s::T) where {T<:FunctionWrapperTangent}
+    t.dobj_ref[] = increment!!(t.dobj_ref[], s.dobj_ref[])
+    return t
+end
+
+function set_to_zero!!(t::FunctionWrapperTangent)
+    t.dobj_ref[] = set_to_zero!!(t.dobj_ref[])
+    return t
+end
+
+function _add_to_primal(p::FunctionWrapper, t::FunctionWrapperTangent, unsafe::Bool)
+    return typeof(p)(_add_to_primal(p.obj[], t.dobj_ref[], unsafe))
+end
+
+function _diff(p::P, q::P) where {R, A, P<:FunctionWrapper{R, A}}
+    return first(_function_wrapper_tangent(R, p.obj[], A, _diff(p.obj[], q.obj[])))
+end
+
+_dot(t::T, s::T) where {T<:FunctionWrapperTangent} = _dot(t.dobj_ref[], s.dobj_ref[])
+
+function _scale(a::Float64, t::T) where {T<:FunctionWrapperTangent}
+    return T(t.fwds_wrapper, Ref(_scale(a, t.dobj_ref[])))
+end
+
+import .TestUtils: populate_address_map!, AddressMap
+function populate_address_map!(m::AddressMap, p::FunctionWrapper, t::FunctionWrapperTangent)
+    k = pointer_from_objref(p)
+    v = pointer_from_objref(t)
+    haskey(m, k) && (@assert m[k] == v)
+    m[k] = v
+    return m
+end
+
+fdata_type(T::Type{<:FunctionWrapperTangent}) = T
+rdata_type(::Type{<:FunctionWrapperTangent}) = NoRData
+tangent_type(F::Type{<:FunctionWrapperTangent}, ::Type{NoRData}) = F
+tangent(f::FunctionWrapperTangent, ::NoRData) = f
+
+_verify_fdata_value(p::FunctionWrapper, t::FunctionWrapperTangent) = nothing
+
+@is_primitive MinimalCtx Tuple{Type{<:FunctionWrapper}, Any}
+function rrule!!(::CoDual{Type{FunctionWrapper{R, A}}}, obj::CoDual{P}) where {R, A, P}
+    t, obj_tangent_ref = _function_wrapper_tangent(R, obj.x, A, zero_tangent(obj.x, obj.dx))
+    function_wrapper_pb(::NoRData) = NoRData(), rdata(obj_tangent_ref[])
+    return CoDual(FunctionWrapper{R, A}(obj.x), t), function_wrapper_pb
+end
+
+@is_primitive MinimalCtx Tuple{<:FunctionWrapper, Vararg}
+function rrule!!(f::CoDual{<:FunctionWrapper}, x::Vararg{CoDual})
+    y, pb = f.dx.fwds_wrapper(x...)
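+    # `f.dx` is the `FunctionWrapperTangent` associated with `f.x`. Calling its
+    # `fwds_wrapper` opaque closure runs the rule for the wrapped object on the arguments'
+    # CoDuals, pushes the resulting pullback onto the internal stack, and returns the
+    # output CoDual together with the reverse-pass opaque closure (see
+    # `_function_wrapper_tangent` above).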
+    function_wrapper_eval_pb(dy) = pb(dy)
+    return y, function_wrapper_eval_pb
+end
+
+function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:function_wrappers})
+    test_cases = Any[
+        (false, :none, nothing, FunctionWrapper{Float64, Tuple{Float64}}, sin),
+        (false, :none, nothing, FunctionWrapper{Float64, Tuple{Float64}}(sin), 5.0),
+    ]
+    memory = Any[]
+    return test_cases, memory
+end
+
+function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:function_wrappers})
+    test_cases = Any[
+        (
+            false, :none, nothing,
+            function(x, y)
+                p = FunctionWrapper{Float64, Tuple{Float64}}(x -> x * y)
+                out = 0.0
+                for _ in 1:1_000
+                    out += p(x)
+                end
+                return out
+            end,
+            5.0, 4.0,
+        ),
+        (
+            false, :none, nothing,
+            function(x::Vector{Float64}, y::Float64)
+                p = FunctionWrapper{Float64, Tuple{Float64}}(x -> x * y)
+                out = 0.0
+                for _x in x
+                    out += p(_x)
+                end
+                return out
+            end,
+            randn(100), randn(),
+        ),
+    ]
+    return test_cases, Any[]
+end
diff --git a/src/rrules/iddict.jl b/src/rrules/iddict.jl
index ed8b57a04..74a2ea67b 100644
--- a/src/rrules/iddict.jl
+++ b/src/rrules/iddict.jl
@@ -40,10 +40,12 @@ function TestUtils.populate_address_map!(m::TestUtils.AddressMap, p::IdDict, t::
     foreach(n -> TestUtils.populate_address_map!(m, p[n], t[n]), keys(p))
     return m
 end
-function TestUtils.has_equal_data(p::P, q::P; equal_undefs=true) where {P<:IdDict}
+function TestUtils.has_equal_data_internal(
+    p::P, q::P, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}
+) where {P<:IdDict}
     ks = union(keys(p), keys(q))
     ks != keys(p) && return false
-    return all([TestUtils.has_equal_data(p[k], q[k]; equal_undefs) for k in ks])
+    return all([TestUtils.has_equal_data_internal(p[k], q[k], equal_undefs, d) for k in ks])
 end
 
 fdata_type(::Type{T}) where {T<:IdDict} = T
diff --git a/src/rrules/twice_precision.jl b/src/rrules/twice_precision.jl
new file mode 100644
index 000000000..3cbcbccb0
--- /dev/null
+++ b/src/rrules/twice_precision.jl
@@ -0,0 +1,377 @@
+# Let `x` be a `TwicePrecision{<:IEEEFloat}`. The `Float64` associated with `x` is roughly
+# `x.hi + x.lo`.
+
+#
+# Implementation of tangents for TwicePrecision. Let `x` be the number that a given
+# TwicePrecision represents, and let its fields be `hi` and `lo`. Since `x = hi + lo`, we
+# should think of `TwicePrecision` not as a struct with two fields, but as a single number.
+# As such, we must ensure that tangents of `TwicePrecision`s depend only on the sum
+# `hi + lo`, not on the values of `hi` and `lo` individually.
+#
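+# For example, `TwicePrecision(1.0, 0.0)` and `TwicePrecision(0.5, 0.5)` both represent
+# the number 1.0, which is why `has_equal_data_internal` below compares
+# `Float64(p) ≈ Float64(q)` rather than comparing the `hi` and `lo` fields individually.
+#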
+
+const TwicePrecisionFloat{P<:IEEEFloat} = TwicePrecision{P}
+const TWP{P} = TwicePrecisionFloat{P}
+
+tangent_type(P::Type{<:TWP}) = P
+
+zero_tangent_internal(::TWP{F}, ::StackDict) where {F} = TWP{F}(zero(F), zero(F))
+
+function randn_tangent_internal(rng::AbstractRNG, p::TWP{F}, ::StackDict) where {F}
+    return TWP{F}(randn(rng, F), randn(rng, F))
+end
+
+import .TestUtils: has_equal_data_internal
+function has_equal_data_internal(
+    p::P, q::P, ::Bool, ::Dict{Tuple{UInt, UInt}, Bool}
+) where {P<:TWP}
+    return Float64(p) ≈ Float64(q)
+end
+
+increment!!(t::T, s::T) where {T<:TWP} = t + s
+
+set_to_zero!!(t::TWP) = zero_tangent_internal(t, nothing)
+
+_add_to_primal(p::P, t::P, ::Bool) where {P<:TWP} = p + t
+
+_diff(p::P, q::P) where {P<:TWP} = p - q
+
+_dot(t::P, s::P) where {P<:TWP} = Float64(t) * Float64(s)
+
+_scale(a::Float64, t::TWP) = a * t
+
+populate_address_map!(m::AddressMap, ::P, ::P) where {P<:TWP} = m
+
+fdata_type(::Type{<:TWP}) = NoFData
+
+rdata_type(P::Type{<:TWP}) = P
+
+_verify_fdata_value(::P, ::P) where {P<:TWP} = nothing
+
+_verify_rdata_value(::P, ::P) where {P<:TWP} = nothing
+
+tangent_type(::Type{NoFData}, T::Type{<:TWP}) = T
+
+tangent(::NoFData, t::TWP) = t
+
+zero_rdata(p::TWP) = zero_tangent(p)
+
+zero_rdata_from_type(P::Type{<:TWP{F}}) where {F} = P(zero(F), zero(F))
+
+#
+# Rules. These are required for a lot of functionality in this case.
+#
+
+@is_primitive MinimalCtx Tuple{typeof(_new_), <:TWP, IEEEFloat, IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(_new_)}, ::CoDual{Type{TWP{P}}}, hi::CoDual{P}, lo::CoDual{P}
+) where {P<:IEEEFloat}
+    _new_twice_precision_pb(dy::TWP{P}) = NoRData(), NoRData(), P(dy), P(dy)
+    return zero_fcodual(_new_(TWP{P}, hi.x, lo.x)), _new_twice_precision_pb
+end
+
+@is_primitive MinimalCtx Tuple{typeof(twiceprecision), IEEEFloat, Integer}
+function rrule!!(
+    ::CoDual{typeof(twiceprecision)}, val::CoDual{P}, nb::CoDual{<:Integer}
+) where {P<:IEEEFloat}
+    twiceprecision_float_pb(dy::TWP{P}) = NoRData(), P(dy), NoRData()
+    return zero_fcodual(twiceprecision(val.x, nb.x)), twiceprecision_float_pb
+end
+
+@is_primitive MinimalCtx Tuple{typeof(twiceprecision), TWP, Integer}
+function rrule!!(
+    ::CoDual{typeof(twiceprecision)}, val::CoDual{P}, nb::CoDual{<:Integer}
+) where {P<:TWP}
+    twiceprecision_pb(dy::P) = NoRData(), dy, NoRData()
+    return zero_fcodual(twiceprecision(val.x, nb.x)), twiceprecision_pb
+end
+
+@is_primitive MinimalCtx Tuple{Type{<:IEEEFloat}, TWP}
+function rrule!!(::CoDual{Type{P}}, x::CoDual{S}) where {P<:IEEEFloat, S<:TWP}
+    float_from_twice_precision_pb(dy::P) = NoRData(), S(dy)
+    return zero_fcodual(P(x.x)), float_from_twice_precision_pb
+end
+
+@is_primitive MinimalCtx Tuple{typeof(-), TWP}
+function rrule!!(::CoDual{typeof(-)}, x::CoDual{P}) where {P<:TWP}
+    negate_twice_precision_pb(dy::P) = NoRData(), -dy
+    return zero_fcodual(-(x.x)), negate_twice_precision_pb
+end
+
+@is_primitive MinimalCtx Tuple{typeof(+), TWP, IEEEFloat}
+function rrule!!(::CoDual{typeof(+)}, x::CoDual{P}, y::CoDual{S}) where {P<:TWP, S<:IEEEFloat}
+    plus_pullback(dz::P) = NoRData(), dz, S(dz)
+    return zero_fcodual(x.x + y.x), plus_pullback
+end
+
+@is_primitive(MinimalCtx, Tuple{typeof(+), P, P} where {P<:TWP})
+function rrule!!(::CoDual{typeof(+)}, x::CoDual{P}, y::CoDual{P}) where {P<:TWP}
+    plus_pullback(dz::P) = NoRData(), dz, dz
+    return zero_fcodual(x.x + y.x), plus_pullback
+end
+
+@is_primitive MinimalCtx Tuple{typeof(*), TWP, IEEEFloat}
+function rrule!!(::CoDual{typeof(*)}, x::CoDual{P}, y::CoDual{S}) where {P<:TWP, S<:IEEEFloat}
+    _x, _y = x.x, y.x
+    mul_twice_precision_and_float_pb(dz::P) = NoRData(), dz * _y, S(dz * _x)
+    return zero_fcodual(_x * _y), mul_twice_precision_and_float_pb
+end
+
+@is_primitive MinimalCtx Tuple{typeof(*), TWP, Integer}
+function rrule!!(::CoDual{typeof(*)}, x::CoDual{P}, y::CoDual{<:Integer}) where {P<:TWP}
+    _y = y.x
+    mul_twice_precision_and_int_pb(dz::P) = NoRData(), dz * _y, NoRData()
+    return zero_fcodual(x.x * _y), mul_twice_precision_and_int_pb
+end
+
+@is_primitive MinimalCtx Tuple{typeof(/), TWP, IEEEFloat}
+function rrule!!(::CoDual{typeof(/)}, x::CoDual{P}, y::CoDual{S}) where {P<:TWP, S<:IEEEFloat}
+    _x, _y = x.x, y.x
+    div_twice_precision_and_float_pb(dz::P) = NoRData(), dz / _y, S(-dz * _x / _y^2)
+    return zero_fcodual(_x / _y), div_twice_precision_and_float_pb
+end
+
+@is_primitive MinimalCtx Tuple{typeof(/), TWP, Integer}
+function rrule!!(::CoDual{typeof(/)}, x::CoDual{P}, y::CoDual{<:Integer}) where {P<:TWP}
+    _y = y.x
+    div_twice_precision_and_int_pb(dz::P) = NoRData(), dz / _y, NoRData()
+    return zero_fcodual(x.x / _y), div_twice_precision_and_int_pb
+end
+
+# Primitives
+
+@zero_adjoint MinimalCtx Tuple{Type{<:TwicePrecision}, Tuple{Integer, Integer}, Integer}
+@zero_adjoint MinimalCtx Tuple{typeof(Base.splitprec), Type, Integer}
+@zero_adjoint(
+    MinimalCtx,
+    Tuple{typeof(Base.floatrange), Type{<:IEEEFloat}, Integer, Integer, Integer, Integer},
+)
+@zero_adjoint(
+    MinimalCtx,
+    Tuple{typeof(Base._linspace), Type{<:IEEEFloat}, Integer, Integer, Integer, Integer},
+)
+
+using Base: range_start_step_length
+@is_primitive(
+    MinimalCtx, Tuple{typeof(range_start_step_length), T, T, Integer} where {T<:IEEEFloat}
+)
+function rrule!!(
+    ::CoDual{typeof(range_start_step_length)},
+    a::CoDual{T},
+    st::CoDual{T},
+    len::CoDual{<:Integer},
+) where {T<:IEEEFloat}
+    pb(dz) = NoRData(), T(dz.data.ref), T(dz.data.step), NoRData()
+    return zero_fcodual(range_start_step_length(a.x, st.x, len.x)), pb
+end
+
+using Base: unsafe_getindex
+const TWPStepRangeLen = StepRangeLen{<:Any, <:TWP, <:TWP}
+@is_primitive(MinimalCtx, Tuple{typeof(unsafe_getindex), TWPStepRangeLen, Integer})
+function rrule!!(
+    ::CoDual{typeof(unsafe_getindex)}, r::CoDual{P}, i::CoDual{<:Integer}
+) where {P<:TWPStepRangeLen}
+    offset = r.x.offset
+    function unsafe_getindex_pb(dy)
+        T = rdata_type(tangent_type(P))
+        dy_twice_precision = TwicePrecision(dy)
+        dref = dy_twice_precision
+        dstep = dy_twice_precision * (i.x - offset)
+        dr = T((ref=dref, step=dstep, len=NoRData(), offset=NoRData()))
+        return NoRData(), dr, NoRData()
+    end
+    return zero_fcodual(unsafe_getindex(r.x, i.x)), unsafe_getindex_pb
+end
+
+using Base: _getindex_hiprec
+@is_primitive(MinimalCtx, Tuple{typeof(_getindex_hiprec), TWPStepRangeLen, Integer})
+function rrule!!(
+    ::CoDual{typeof(_getindex_hiprec)}, r::CoDual{P}, i::CoDual{<:Integer}
+) where {P<:TWPStepRangeLen}
+    offset = r.x.offset
+    function unsafe_getindex_pb(dy)
+        T = rdata_type(tangent_type(P))
+        dref = dy
+        dstep = dy * (i.x - offset)
+        dr = T((ref=dref, step=dstep, len=NoRData(), offset=NoRData()))
+        return NoRData(), dr, NoRData()
+    end
+    return zero_fcodual(_getindex_hiprec(r.x, i.x)), unsafe_getindex_pb
+end
+
+@is_primitive MinimalCtx Tuple{typeof(:), P, P, P} where {P<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(:)}, start::CoDual{P}, step::CoDual{P}, stop::CoDual{P}
+) where {P<:IEEEFloat}
+    colon_pb(dy::RData) = NoRData(), P(dy.data.ref), P(dy.data.step), zero(P)
+    return zero_fcodual((:)(start.x, step.x, stop.x)), colon_pb
+end
+
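+# A short derivation of the `sum` pullback below: for a `StepRangeLen` with fields `ref`,
+# `step`, `offset`, and `len = l`, the `i`th element is roughly `ref + (i - offset) * step`,
+# so
+#
+#   sum(x) ≈ l * ref + (l * (l + 1) / 2 - l * offset) * step.
+#
+# Differentiating with respect to `ref` and `step` gives the `dref` and `dstep` terms used
+# in the pullback.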
+@is_primitive MinimalCtx Tuple{typeof(sum), TWPStepRangeLen}
+function rrule!!(::CoDual{typeof(sum)}, x::CoDual{P}) where {P<:TWPStepRangeLen}
+    l = x.x.len
+    offset = x.x.offset
+    function sum_pb(dy::Float64)
+        R = rdata_type(tangent_type(P))
+        dref = TwicePrecision(l * dy)
+        dstep = TwicePrecision(dy * (0.5 * l * (l + 1) - l * offset))
+        dx = R((ref=dref, step=dstep, len=NoRData(), offset=NoRData()))
+        return NoRData(), dx
+    end
+    return zero_fcodual(sum(x.x)), sum_pb
+end
+
+@is_primitive(
+    MinimalCtx,
+    Tuple{typeof(Base.range_start_stop_length), P, P, Integer} where {P<:IEEEFloat},
+)
+function rrule!!(
+    ::CoDual{typeof(Base.range_start_stop_length)},
+    start::CoDual{P},
+    stop::CoDual{P},
+    length::CoDual{<:Integer},
+) where {P<:IEEEFloat}
+    l = (length.x - 1)
+    function range_start_stop_length_pb(dy::RData)
+        dstart = P(dy.data.ref) - P(dy.data.step) / l
+        dstop = P(dy.data.step) / l
+        return NoRData(), dstart, dstop, NoRData()
+    end
+    y = zero_fcodual(Base.range_start_stop_length(start.x, stop.x, length.x))
+    return y, range_start_stop_length_pb
+end
+
+@static if VERSION >= v"1.11"
+
+@is_primitive MinimalCtx Tuple{typeof(Base._exp_allowing_twice64), TwicePrecision{Float64}}
+function rrule!!(
+    ::CoDual{typeof(Base._exp_allowing_twice64)}, x::CoDual{TwicePrecision{Float64}}
+)
+    y = Base._exp_allowing_twice64(x.x)
+    _exp_allowing_twice64_pb(dy::Float64) = NoRData(), TwicePrecision(dy * y)
+    return zero_fcodual(y), _exp_allowing_twice64_pb
+end
+
+@is_primitive(MinimalCtx, Tuple{typeof(Base._log_twice64_unchecked), Float64})
+function rrule!!(::CoDual{typeof(Base._log_twice64_unchecked)}, x::CoDual{Float64})
+    _x = x.x
+    _log_twice64_pb(dy::TwicePrecision{Float64}) = NoRData(), Float64(dy) / _x
+    return zero_fcodual(Base._log_twice64_unchecked(_x)), _log_twice64_pb
+end
+
+end
+
+function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:twice_precision})
+    test_cases = Any[
+        (
+            false, :stability_and_allocs, nothing,
+            _new_, TwicePrecisionFloat{Float64}, 5.0, 4.0
+        ),
+        (false, :stability_and_allocs, nothing, twiceprecision, 5.0, 4),
+        (false, :stability_and_allocs, nothing, twiceprecision, TwicePrecision(5.0), 4),
+        (false, :stability_and_allocs, nothing, Float64, TwicePrecision(5.0, 3.0)),
+        (false, :stability_and_allocs, nothing, -, TwicePrecision(5.0, 3.0)),
+        (false, :stability_and_allocs, nothing, +, TwicePrecision(5.0, 3.0), 4.0),
+        (
+            false, :stability_and_allocs, nothing,
+            +, TwicePrecision(5.0, 3.0), TwicePrecision(4.0, 5.0),
+        ),
+        (false, :stability_and_allocs, nothing, *, TwicePrecision(5.0, 1e-12), 3.0),
+        (false, :stability_and_allocs, nothing, *, TwicePrecision(5.0, 1e-12), 3),
+        (false, :stability_and_allocs, nothing, /, TwicePrecision(5.0, 1e-12), 3.0),
+        (false, :stability_and_allocs, nothing, /, TwicePrecision(5.0, 1e-12), 3),
+
+        (false, :stability_and_allocs, nothing, Base.splitprec, Float64, 5),
+        (false, :stability_and_allocs, nothing, Base.splitprec, Float32, 5),
+        (false, :stability_and_allocs, nothing, Base.splitprec, Float16, 5),
+
+        (false, :stability_and_allocs, nothing, Base.floatrange, Float64, 5, 6, 7, 8),
+        (false, :stability_and_allocs, nothing, Base._linspace, Float64, 5, 6, 7, 8),
+        (false, :stability_and_allocs, nothing, Base.range_start_step_length, 5.0, 6.0, 10),
+        (
+            false, :stability_and_allocs, nothing,
+            Base.range_start_step_length, 5.0, Float64(π), 10,
+        ),
+        (
+            false, :stability_and_allocs, nothing,
+            unsafe_getindex,
+            StepRangeLen(TwicePrecision(-0.45), TwicePrecision(0.98), 10, 3),
+            5,
+        ),
+        (
+            false, :stability_and_allocs, nothing,
+            _getindex_hiprec,
+            StepRangeLen(TwicePrecision(-0.45), TwicePrecision(0.98), 10, 3),
+            5,
+        ),
+        (false, :stability_and_allocs, nothing, (:), -0.1, 0.99, 5.1),
+        (false, :stability_and_allocs, nothing, sum, range(-0.1, 9.9; length=51)),
+        (
+            false, :stability_and_allocs, nothing,
+            Base.range_start_stop_length, -0.5, 11.7, 7,
+        ),
+        (
+            false, :stability_and_allocs, nothing,
+            Base.range_start_stop_length, -0.5, -11.7, 11,
+        ),
+    ]
+    @static if VERSION >= v"1.11"
+        extra_test_cases = Any[
+            (
+                false, :stability_and_allocs, nothing,
+                Base._exp_allowing_twice64, TwicePrecision(2.0),
+            ),
+            (false, :stability_and_allocs, nothing, Base._log_twice64_unchecked, 3.0),
+        ]
+        test_cases = vcat(test_cases, extra_test_cases)
+    end
+    memory = Any[]
+    return test_cases, memory
+end
+
+function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:twice_precision})
+    test_cases = Any[
+
+        # Functionality in base/twiceprecision.jl
+        (false, :allocs, nothing, TwicePrecision{Float64}, 5.0, 0.3),
+        (false, :allocs, nothing, (x, y) -> Float64(TwicePrecision{Float64}(x, y)), 5.0, 0.3),
+        (false, :allocs, nothing, TwicePrecision, 5.0, 0.3),
+        (false, :allocs, nothing, (x, y) -> Float64(TwicePrecision(x, y)), 5.0, 0.3),
+        (false, :allocs, nothing, TwicePrecision{Float64}, 5.0),
+        (false, :allocs, nothing, x -> Float64(TwicePrecision{Float64}(x)), 5.0),
+        (false, :allocs, nothing, TwicePrecision, 5.0),
+        (false, :allocs, nothing, x -> Float64(TwicePrecision(x)), 5.0),
+        (false, :allocs, nothing, TwicePrecision{Float64}, 5),
+        (false, :allocs, nothing, x -> Float64(TwicePrecision{Float64}(x)), 5),
+        (false, :none, nothing, TwicePrecision{Float64}, (5, 4)),
+        (false, :none, nothing, x -> Float64(TwicePrecision{Float64}(x)), (5, 4)),
+        (false, :none, nothing, TwicePrecision{Float64}, (5, 4), 3),
+        (false, :none, nothing, (x, y) -> Float64(TwicePrecision{Float64}(x, y)), (5, 4), 3),
+        (false, :allocs, nothing, +, TwicePrecision(5.0), TwicePrecision(4.0)),
+        (false, :allocs, nothing, +, 5.0, TwicePrecision(4.0)),
+        (false, :allocs, nothing, +, TwicePrecision(5.0), 4.0),
+        (false, :allocs, nothing, -, TwicePrecision(5.0), TwicePrecision(4.0)),
+        (false, :allocs, nothing, -, 5.0, TwicePrecision(4.0)),
+        (false, :allocs, nothing, -, TwicePrecision(5.0), 4.0),
+        (false, :allocs, nothing, *, 3.0, TwicePrecision(5.0, 1e-12)),
+        (false, :allocs, nothing, *, 3, TwicePrecision(5.0, 1e-12)),
+        (
+            false, :allocs, nothing,
+            getindex,
+            StepRangeLen(TwicePrecision(-0.45), TwicePrecision(0.98), 10, 3),
+            2:2:6,
+        ),
+        (
+            false, :allocs, nothing,
+            +, range(0.0, 5.0; length=44), range(-33.0, 4.5; length=44),
+        ),
+
+        # Functionality in base/range.jl
+        (false, :allocs, nothing, range, 0.0, 5.6),
+        (false, :allocs, nothing, (lb, ub) -> range(lb, ub; length=10), -0.45, 9.5),
+    ]
+    @static if VERSION >= v"1.11"
+        push!(test_cases, (false, :allocs, nothing, Base._logrange_extra, 1.1, 3.5, 5))
+        push!(test_cases, (false, :allocs, nothing, logrange, 5.0, 10.0, 11))
+    end
+    return test_cases, Any[]
+end
diff --git a/src/tangents.jl b/src/tangents.jl
index b77c03b4b..f28564f22 100644
--- a/src/tangents.jl
+++ b/src/tangents.jl
@@ -436,6 +436,8 @@ function zero_tangent(x::P) where {P}
     return zero_tangent_internal(x, isbitstype(P) ? nothing : IdDict())
 end
 
+const StackDict = Union{Nothing, IdDict}
+
 # the `stackdict` naming following convention of Julia's `deepcopy` and `deepcopy_internal`
 # https://github.com/JuliaLang/julia/blob/48d4fd48430af58502699fdf3504b90589df3852/base/deepcopy.jl#L35
 @inline zero_tangent_internal(::Union{Int8, Int16, Int32, Int64, Int128}, ::Any) = NoTangent()
diff --git a/src/test_utils.jl b/src/test_utils.jl
index ee91cf0f4..4b77b9a24 100644
--- a/src/test_utils.jl
+++ b/src/test_utils.jl
@@ -791,7 +791,6 @@ function test_tangent_consistency(rng::AbstractRNG, p::P; interface_only=false)
 
     # Verify that operations required for finite difference testing to run, and produce the
     # correct output type.
-    @test _add_to_primal(p, t) isa P
     @test _add_to_primal(p, t, true) isa P
     @test _diff(p, p) isa T
     @test _dot(t, t) isa Float64
diff --git a/test/front_matter.jl b/test/front_matter.jl
index 276605027..976fe2537 100644
--- a/test/front_matter.jl
+++ b/test/front_matter.jl
@@ -12,11 +12,12 @@ using
 
 import ChainRulesCore
 
-using Base: unsafe_load, pointer_from_objref, IEEEFloat
+using Base: unsafe_load, pointer_from_objref, IEEEFloat, TwicePrecision
 using Base.Iterators: product
 using Core:
     bitcast, svec, ReturnNode, PhiNode, PiNode, GotoIfNot, GotoNode, SSAValue, Argument
 using Core.Intrinsics: pointerref, pointerset
+using FunctionWrappers: FunctionWrapper
 using Mooncake:
     CC,
diff --git a/test/integration_testing/temporalgps/temporalgps.jl b/test/integration_testing/temporalgps/temporalgps.jl
index 280b8fa86..3e03fdec1 100644
--- a/test/integration_testing/temporalgps/temporalgps.jl
+++ b/test/integration_testing/temporalgps/temporalgps.jl
@@ -13,6 +13,7 @@ temporalgps_logpdf_tester(k, x, y, s) = logpdf(build_gp(k)(x, s), y)
 xs = Any[
     collect(range(-5.0; step=0.1, length=1_000)),
     RegularSpacing(0.0, 0.1, 1_000),
+    range(-5.0; step=0.1, length=1_000),
 ]
 base_kernels = Any[Matern12Kernel(), Matern32Kernel()]
 kernels = vcat(base_kernels, [with_lengthscale(k, 1.1) for k in base_kernels])
diff --git a/test/rrules/function_wrappers.jl b/test/rrules/function_wrappers.jl
new file mode 100644
index 000000000..fa9c16b27
--- /dev/null
+++ b/test/rrules/function_wrappers.jl
@@ -0,0 +1,12 @@
+@testset "function_wrappers" begin
+    rng = Xoshiro(123)
+    _data = Ref{Float64}(5.0)
+    @testset "$p" for p in Any[
+        FunctionWrapper{Float64, Tuple{Float64}}(sin),
+        FunctionWrapper{Float64, Tuple{Float64}}(x -> x * _data[]),
+    ]
+        TestUtils.test_tangent_consistency(rng, p)
+        TestUtils.test_fwds_rvs_data(rng, p)
+    end
+    TestUtils.run_rrule!!_test_cases(StableRNG, Val(:function_wrappers))
+end
diff --git a/test/rrules/twice_precision.jl b/test/rrules/twice_precision.jl
new file mode 100644
index 000000000..d3ea969c1
--- /dev/null
+++ b/test/rrules/twice_precision.jl
@@ -0,0 +1,7 @@
+@testset "twice_precision" begin
+    rng = sr(123)
+    p = Base.TwicePrecision{Float64}(5.0, 4.0)
+    TestUtils.test_tangent_consistency(rng, p)
+    TestUtils.test_fwds_rvs_data(rng, p)
+    TestUtils.run_rrule!!_test_cases(StableRNG, Val(:twice_precision))
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index a3712fb22..eea7bdc62 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -36,6 +36,8 @@ include("front_matter.jl")
         include(joinpath("rrules", "fastmath.jl"))
         @info "foreigncall"
         include(joinpath("rrules", "foreigncall.jl"))
+        @info "function_wrappers"
+        include(joinpath("rrules", "function_wrappers.jl"))
         @info "iddict"
         include(joinpath("rrules", "iddict.jl"))
         @info "lapack"
@@ -50,6 +52,8 @@ include("front_matter.jl")
         include(joinpath("rrules", "new.jl"))
"new.jl")) @info "tasks" include(joinpath("rrules", "tasks.jl")) + @info "twice_precision" + include(joinpath("rrules", "twice_precision.jl")) @static if VERSION >= v"1.11.0-rc4" @info "memory" include(joinpath("rrules", "memory.jl"))