diff --git a/Project.toml b/Project.toml index a26221388..274421ecb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.33" +version = "0.4.34" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index e01c62195..d93f3fa3e 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -794,6 +794,8 @@ end _is_primitive(C::Type, mi::Core.MethodInstance) = is_primitive(C, mi.specTypes) _is_primitive(C::Type, sig::Type) = is_primitive(C, sig) +const RuleMC{A, R} = MistyClosure{OpaqueClosure{A, R}} + # Compute the concrete type of the rule that will be returned from `build_rrule`. This is # important for performance in dynamic dispatch, and to ensure that recursion works # properly. @@ -808,20 +810,30 @@ function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where isva, _ = is_vararg_and_sparam_names(sig_or_mi) arg_types = map(_type, ir.argtypes) - primal_sig = Tuple{arg_types...} + sig = Tuple{arg_types...} arg_fwds_types = Tuple{map(fcodual_type, arg_types)...} arg_rvs_types = Tuple{map(rdata_type ∘ tangent_type, arg_types)...} rvs_return_type = rdata_type(tangent_type(Treturn)) - pb_type = MistyClosure{OpaqueClosure{Tuple{rvs_return_type}, arg_rvs_types}} - - Tderived_rule = DerivedRule{ - primal_sig, - MistyClosure{OpaqueClosure{arg_fwds_types, fcodual_type(Treturn)}}, - Pullback{primal_sig, Base.RefValue{pb_type}, Val{isva}, nvargs(isva, primal_sig)}, - Val{isva}, - Val{length(ir.argtypes)}, - } - return debug_mode ? DebugRRule{Tderived_rule} : Tderived_rule + pb_oc_type = MistyClosure{OpaqueClosure{Tuple{rvs_return_type}, arg_rvs_types}} + pb_type = Pullback{sig, Base.RefValue{pb_oc_type}, Val{isva}, nvargs(isva, sig)} + nargs = Val{length(ir.argtypes)} + + if isconcretetype(Treturn) + Tderived_rule = DerivedRule{ + sig, RuleMC{arg_fwds_types, fcodual_type(Treturn)}, pb_type, Val{isva}, nargs, + } + return debug_mode ? DebugRRule{Tderived_rule} : Tderived_rule + else + if debug_mode + return DebugRRule{DerivedRule{ + sig, RuleMC{arg_fwds_types, P}, pb_type, Val{isva}, nargs, + }} where {P<:fcodual_type(Treturn)} + else + return DerivedRule{ + sig, RuleMC{arg_fwds_types, P}, pb_type, Val{isva}, nargs, + } where {P<:fcodual_type(Treturn)} + end + end end nvargs(isva, sig) = Val{isva ? length(sig.parameters[end].parameters) : 0} diff --git a/src/test_resources.jl b/src/test_resources.jl index 387410c28..5075646fb 100644 --- a/src/test_resources.jl +++ b/src/test_resources.jl @@ -577,6 +577,8 @@ function typevar_tester() return UnionAll(tv, t) end +tuple_with_union(x::Bool) = (x ? 5.0 : 5, nothing) + function generate_test_functions() return Any[ (false, :allocs, nothing, const_tester), diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index 15988cc17..097119a40 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -244,8 +244,9 @@ end @testset "rule_type $sig, $debug_mode" for sig in Any[ Tuple{typeof(getfield), Tuple{Float64}, 1}, - Tuple{typeof(Mooncake.TestResources.foo), Float64}, - Tuple{typeof(Mooncake.TestResources.type_unstable_tester_0), Ref{Any}}, + Tuple{typeof(TestResources.foo), Float64}, + Tuple{typeof(TestResources.type_unstable_tester_0), Ref{Any}}, + Tuple{typeof(TestResources.tuple_with_union), Bool}, ], debug_mode in [true, false]