diff --git a/Project.toml b/Project.toml index bc3251174..bce33e2ba 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.28" +version = "0.2.29" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl index 020054db1..328bc648f 100644 --- a/src/fwds_rvs_data.jl +++ b/src/fwds_rvs_data.jl @@ -737,8 +737,8 @@ end # zero element and use it later. L is the precise type of `LazyZeroRData` that you wish to # construct -- very occassionally you need complete control over this, but don't want to # figure out for yourself whether or not construction can be performed lazily. -@inline function lazy_zero_rdata(::Type{L}, p::P) where {L<:LazyZeroRData, P} - return L(can_produce_zero_rdata_from_type(P) ? nothing : zero_rdata(p)) +@inline function lazy_zero_rdata(::Type{L}, p::P) where {S, L<:LazyZeroRData{S}, P} + return L(can_produce_zero_rdata_from_type(S) ? nothing : zero_rdata(p)) end # If type parameters for `LazyZeroRData` are not provided, use the defaults. diff --git a/src/interpreter/ir_normalisation.jl b/src/interpreter/ir_normalisation.jl index 33e9067e6..3ed88af36 100644 --- a/src/interpreter/ir_normalisation.jl +++ b/src/interpreter/ir_normalisation.jl @@ -15,7 +15,7 @@ static parameter names have been translated into either types, or `:static_param expressions. Unfortunately, the static parameter names are not retained in `IRCode`, and the `Method` -from which the `IRCode` is derived must be consulted. `Tapir.is_vararg_sig_and_sparam_names` +from which the `IRCode` is derived must be consulted. `Tapir.is_vararg_and_sparam_names` provides a convenient way to do this. """ function normalise!(ir::IRCode, spnames::Vector{Symbol}) diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index c30a70f01..325d6949f 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -171,10 +171,13 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) end """ - lookup_ir(interp::AbstractInterpreter, sig::Type{<:Tuple})::Tuple{IRCode, T} + lookup_ir( + interp::AbstractInterpreter, + sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance}, + )::Tuple{IRCode, T} -Get the IR unique IR associated to `sig` under `interp`. Throws `ArgumentError`s if there is -no code found, or if more than one `IRCode` instance returned. +Get the IR unique IR associated to `sig_or_mi` under `interp`. Throws `ArgumentError`s if +there is no code found, or if more than one `IRCode` instance returned. Returns a tuple containing the `IRCode` and its return type. """ @@ -188,6 +191,10 @@ function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple}) return only(output) end +function lookup_ir(interp::CC.AbstractInterpreter, mi::Core.MethodInstance) + return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, nothing) +end + """ is_reachable_return_node(x::ReturnNode) diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 6583d526d..50f05358c 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -467,7 +467,8 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) raw_rule = if is_primitive(context_type(info.interp), sig) rrule!! # intrinsic / builtin / thing we provably have rule for elseif is_invoke - LazyDerivedRule(info.interp, sig, info.safety_on) # Static dispatch + mi = stmt.args[1]::Core.MethodInstance + LazyDerivedRule(info.interp, mi, info.safety_on) # Static dispatch else DynamicDerivedRule(info.interp, info.safety_on) # Dynamic dispatch end @@ -701,15 +702,18 @@ end # Rule derivation. # +_is_primitive(C::Type, mi::Core.MethodInstance) = is_primitive(C, mi.specTypes) +_is_primitive(C::Type, sig::Type) = is_primitive(C, sig) + # Compute the concrete type of the rule that will be returned from `build_rrule`. This is # important for performance in dynamic dispatch, and to ensure that recursion works # properly. -function rule_type(interp::TapirInterpreter{C}, ::Type{sig}) where {C, sig} - is_primitive(C, sig) && return typeof(rrule!!) +function rule_type(interp::TapirInterpreter{C}, sig_or_mi) where {C} + _is_primitive(C, sig_or_mi) && return typeof(rrule!!) - ir, _ = lookup_ir(interp, sig) + ir, _ = lookup_ir(interp, sig_or_mi) Treturn = Base.Experimental.compute_ir_rettype(ir) - isva, _ = is_vararg_sig_and_sparam_names(sig) + isva, _ = is_vararg_and_sparam_names(sig_or_mi) arg_types = map(_type, ir.argtypes) arg_fwds_types = Tuple{map(fcodual_type, arg_types)...} @@ -743,20 +747,20 @@ function build_rrule(args...; safety_on=false) end """ - build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C} + build_rrule(interp::PInterp{C}, sig_or_mi; safety_on=false) where {C} -Returns a `DerivedRule` which is an `rrule!!` for `sig` in context `C`. See the docstring +Returns a `DerivedRule` which is an `rrule!!` for `sig_or_mi` in context `C`. See the docstring for `rrule!!` for more info. If `safety_on` is `true`, then all calls to rules are replaced with calls to `SafeRRule`s. """ function build_rrule( - interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false, silence_safety_messages=true + interp::PInterp{C}, sig_or_mi; safety_on=false, silence_safety_messages=true ) where {C} # If we're compiling in safe mode, let the user know by default. if !silence_safety_messages && safety_on - @info "Compiling rule for $sig in safe mode. Disable for best performance." + @info "Compiling rule for $sig_or_mi in safe mode. Disable for best performance." end # Reset id count. This ensures that the IDs generated are the same each time this @@ -764,14 +768,14 @@ function build_rrule( seed_id!() # If we have a hand-coded rule, just use that. - is_primitive(C, sig) && return (safety_on ? SafeRRule(rrule!!) : rrule!!) + _is_primitive(C, sig_or_mi) && return (safety_on ? SafeRRule(rrule!!) : rrule!!) # Grab code associated to the primal. - ir, _ = lookup_ir(interp, sig) + ir, _ = lookup_ir(interp, sig_or_mi) Treturn = Base.Experimental.compute_ir_rettype(ir) # Normalise the IR, and generated BBCode version of it. - isva, spnames = is_vararg_sig_and_sparam_names(sig) + isva, spnames = is_vararg_and_sparam_names(sig_or_mi) ir = normalise!(ir, spnames) primal_ir = BBCode(ir) @@ -791,8 +795,8 @@ function build_rrule( # If we've already derived the OpaqueClosures and info, do not re-derive, just create a # copy and pass in new shared data. - if haskey(interp.oc_cache, (sig, safety_on)) - existing_fwds_oc, existing_pb_oc = interp.oc_cache[(sig, safety_on)] + if haskey(interp.oc_cache, (sig_or_mi, safety_on)) + existing_fwds_oc, existing_pb_oc = interp.oc_cache[(sig_or_mi, safety_on)] fwds_oc = replace_captures(existing_fwds_oc, shared_data) pb_oc = replace_captures(existing_pb_oc, shared_data) else @@ -801,7 +805,7 @@ function build_rrule( optimised_fwds_ir = optimise_ir!(optimise_ir!(IRCode(fwds_ir); do_inline=true)) optimised_pb_ir = optimise_ir!(optimise_ir!(IRCode(pb_ir); do_inline=true)) - # @show sig + # @show sig_or_mi # @show Treturn # @show safety_on # display(ir) @@ -820,10 +824,10 @@ function build_rrule( OpaqueClosure(optimised_pb_ir, shared_data...; do_compile=true), optimised_pb_ir, ) - interp.oc_cache[(sig, safety_on)] = (fwds_oc, pb_oc) + interp.oc_cache[(sig_or_mi, safety_on)] = (fwds_oc, pb_oc) end - raw_rule = rule_type(interp, sig)(fwds_oc, pb_oc, Val(isva), Val(num_args(info))) + raw_rule = rule_type(interp, sig_or_mi)(fwds_oc, pb_oc, Val(isva), Val(num_args(info))) return safety_on ? SafeRRule(raw_rule) : raw_rule end @@ -1230,7 +1234,7 @@ function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N} end #= - LazyDerivedRule(interp, sig, safety_on::Bool) + LazyDerivedRule(interp, mi::Core.MethodInstance, safety_on::Bool) For internal use only. @@ -1242,19 +1246,20 @@ If `safety_on` is `true`, then the rule constructed will be a `SafeRRule`. This when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead. =# -mutable struct LazyDerivedRule{sig, Tinterp<:TapirInterpreter, Trule} +mutable struct LazyDerivedRule{Tinterp<:TapirInterpreter, Trule} interp::Tinterp safety_on::Bool + mi::Core.MethodInstance rule::Trule - function LazyDerivedRule(interp::A, ::Type{sig}, safety_on::Bool) where {A, sig} - rt = safety_on ? SafeRRule{rule_type(interp, sig)} : rule_type(interp, sig) - return new{sig, A, rt}(interp, safety_on) + function LazyDerivedRule(interp::A, mi::Core.MethodInstance, safety_on::Bool) where {A} + rt = rule_type(interp, mi) + return new{A, safety_on ? SafeRRule{rt} : rt}(interp, safety_on, mi) end end -function (rule::LazyDerivedRule{sig})(args::Vararg{Any, N}) where {N, sig} +function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N} if !isdefined(rule, :rule) - rule.rule = build_rrule(rule.interp, sig; safety_on=rule.safety_on) + rule.rule = build_rrule(rule.interp, rule.mi; safety_on=rule.safety_on) end return rule.rule(args...) end diff --git a/src/tangents.jl b/src/tangents.jl index 828b1eeda..ad16ef40d 100644 --- a/src/tangents.jl +++ b/src/tangents.jl @@ -296,6 +296,8 @@ tangent_type(::Type{Nothing}) = NoTangent tangent_type(::Type{Expr}) = NoTangent +tangent_type(::Type{Core.TypeofVararg}) = NoTangent + tangent_type(::Type{SimpleVector}) = Vector{Any} tangent_type(::Type{P}) where {P<:Union{UInt8, UInt16, UInt32, UInt64, UInt128}} = NoTangent diff --git a/src/test_utils.jl b/src/test_utils.jl index 2f249505f..537ba9f9a 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1450,6 +1450,18 @@ end test_getfield_of_tuple_of_types(n::Int) = getfield((Float64, Float64), n) +test_for_invoke(x) = 5x + +inlinable_invoke_call(x::Float64) = invoke(test_for_invoke, Tuple{Float64}, x) + +vararg_test_for_invoke(n::Tuple{Int, Int}, x...) = sum(x) + n[1] + +function inlinable_vararg_invoke_call( + rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N} +) where {N} + return invoke(vararg_test_for_invoke, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...) +end + function generate_test_functions() return Any[ (false, :allocs, nothing, const_tester), @@ -1621,6 +1633,9 @@ function generate_test_functions() (false, :none, nothing, ArgumentError, "hi"), (false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}(5.0)), (false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}([1.0])), + (false, :allocs, nothing, inlinable_invoke_call, 5.0), + (false, :none, nothing, inlinable_vararg_invoke_call, (2, 2), 5.0, 4.0, 3.0, 2.0), + (false, :none, nothing, hvcat, (2, 2), 3.0, 2.0, 0.0, 1.0), ] end diff --git a/src/utils.jl b/src/utils.jl index 9d5124878..8c2f722f0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -114,19 +114,33 @@ The usual function `map` doesn't enforce this for `Array`s. end #= - is_vararg_sig_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} + is_vararg_and_sparam_names(m::Method) -Returns a 2-tuple. The first element is true if the method associated to `sig` is a vararg -method, and false if not. The second element contains all of the names of the static -parameters associated to said method. +Returns a 2-tuple. The first element is true if `m` is a vararg method, and false if not. +The second element contains the names of the static parameters associated to `m`. =# -function is_vararg_sig_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} +is_vararg_and_sparam_names(m::Method) = m.isva, sparam_names(m) + +#= + is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} + +Finds the method associated to `sig`, and calls `is_vararg_and_sparam_names` on it. +=# +function is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} world = Base.get_world_counter() min = Base.RefValue{UInt}(typemin(UInt)) max = Base.RefValue{UInt}(typemax(UInt)) ms = Base._methods_by_ftype(sig, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))::Vector - m = only(ms).method - return m.isva, sparam_names(m) + return is_vararg_and_sparam_names(only(ms).method) +end + +#= + is_vararg_and_sparam_names(mi::Core.MethodInstance) + +Calls `is_vararg_and_sparam_names` on `mi.def::Method`. +=# +function is_vararg_and_sparam_names(mi::Core.MethodInstance)::Tuple{Bool, Vector{Symbol}} + return is_vararg_and_sparam_names(mi.def) end # Returns the names of all of the static parameters in `m`. diff --git a/test/fwds_rvs_data.jl b/test/fwds_rvs_data.jl index 95f96da26..c25735ce1 100644 --- a/test/fwds_rvs_data.jl +++ b/test/fwds_rvs_data.jl @@ -20,20 +20,25 @@ end @testset "lazy construction checks" begin # Check that lazy construction is in fact lazy for some cases where performance # really matters -- floats, things with no rdata, etc. - @testset "$p" for (p, fully_lazy) in Any[ - (5, true), - (Int32(5), true), - (5.0, true), - (5f0, true), - (Float16(5.0), true), - (StructFoo(5.0), false), - (StructFoo(5.0, randn(4)), false), - (Bool, true), - (Tapir.TestResources.StableFoo, true), + @testset "$p" for (P, p, fully_lazy) in Any[ + (Int, 5, true), + (Int32, Int32(5), true), + (Float64, 5.0, true), + (Float32, 5f0, true), + (Float16, Float16(5.0), true), + (StructFoo, StructFoo(5.0), false), + (StructFoo, StructFoo(5.0, randn(4)), false), + (Type{Bool}, Bool, true), + (Type{Tapir.TestResources.StableFoo}, Tapir.TestResources.StableFoo, true), + (Tuple{Float64, Float64}, (5.0, 4.0), true), + (Tuple{Float64, Vararg{Float64}}, (5.0, 4.0, 3.0), false), ] - @test fully_lazy == Base.issingletontype(typeof(lazy_zero_rdata(p))) - @inferred Tapir.instantiate(lazy_zero_rdata(p)) - @test typeof(lazy_zero_rdata(p)) == Tapir.lazy_zero_rdata_type(_typeof(p)) + L = Tapir.lazy_zero_rdata_type(P) + @test fully_lazy == Base.issingletontype(typeof(lazy_zero_rdata(L, p))) + if isconcretetype(P) + @inferred Tapir.instantiate(lazy_zero_rdata(L, p)) + end + @test typeof(lazy_zero_rdata(L, p)) == Tapir.lazy_zero_rdata_type(P) end @test isa( lazy_zero_rdata(Tapir.TestResources.StableFoo),