Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

invoke attempt 2 #212

Merged
merged 10 commits into from
Aug 2, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.28"
version = "0.2.29"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 2 additions & 2 deletions src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -737,8 +737,8 @@ end
# zero element and use it later. L is the precise type of `LazyZeroRData` that you wish to
# construct -- very occassionally you need complete control over this, but don't want to
# figure out for yourself whether or not construction can be performed lazily.
@inline function lazy_zero_rdata(::Type{L}, p::P) where {L<:LazyZeroRData, P}
return L(can_produce_zero_rdata_from_type(P) ? nothing : zero_rdata(p))
@inline function lazy_zero_rdata(::Type{L}, p::P) where {S, L<:LazyZeroRData{S}, P}
return L(can_produce_zero_rdata_from_type(S) ? nothing : zero_rdata(p))
end

# If type parameters for `LazyZeroRData` are not provided, use the defaults.
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ static parameter names have been translated into either types, or `:static_param
expressions.

Unfortunately, the static parameter names are not retained in `IRCode`, and the `Method`
from which the `IRCode` is derived must be consulted. `Tapir.is_vararg_sig_and_sparam_names`
from which the `IRCode` is derived must be consulted. `Tapir.is_vararg_and_sparam_names`
provides a convenient way to do this.
"""
function normalise!(ir::IRCode, spnames::Vector{Symbol})
Expand Down
13 changes: 10 additions & 3 deletions src/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,13 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true)
end

"""
lookup_ir(interp::AbstractInterpreter, sig::Type{<:Tuple})::Tuple{IRCode, T}
lookup_ir(
interp::AbstractInterpreter,
sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance},
)::Tuple{IRCode, T}

Get the IR unique IR associated to `sig` under `interp`. Throws `ArgumentError`s if there is
no code found, or if more than one `IRCode` instance returned.
Get the IR unique IR associated to `sig_or_mi` under `interp`. Throws `ArgumentError`s if
there is no code found, or if more than one `IRCode` instance returned.

Returns a tuple containing the `IRCode` and its return type.
"""
Expand All @@ -188,6 +191,10 @@ function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple})
return only(output)
end

function lookup_ir(interp::CC.AbstractInterpreter, mi::Core.MethodInstance)
return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, nothing)
end

"""
is_reachable_return_node(x::ReturnNode)

Expand Down
53 changes: 29 additions & 24 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,8 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
raw_rule = if is_primitive(context_type(info.interp), sig)
rrule!! # intrinsic / builtin / thing we provably have rule for
elseif is_invoke
LazyDerivedRule(info.interp, sig, info.safety_on) # Static dispatch
mi = stmt.args[1]::Core.MethodInstance
LazyDerivedRule(info.interp, mi, info.safety_on) # Static dispatch
else
DynamicDerivedRule(info.interp, info.safety_on) # Dynamic dispatch
end
Expand Down Expand Up @@ -701,15 +702,18 @@ end
# Rule derivation.
#

_is_primitive(C::Type, mi::Core.MethodInstance) = is_primitive(C, mi.specTypes)
_is_primitive(C::Type, sig::Type) = is_primitive(C, sig)

# Compute the concrete type of the rule that will be returned from `build_rrule`. This is
# important for performance in dynamic dispatch, and to ensure that recursion works
# properly.
function rule_type(interp::TapirInterpreter{C}, ::Type{sig}) where {C, sig}
is_primitive(C, sig) && return typeof(rrule!!)
function rule_type(interp::TapirInterpreter{C}, sig_or_mi) where {C}
_is_primitive(C, sig_or_mi) && return typeof(rrule!!)

ir, _ = lookup_ir(interp, sig)
ir, _ = lookup_ir(interp, sig_or_mi)
Treturn = Base.Experimental.compute_ir_rettype(ir)
isva, _ = is_vararg_sig_and_sparam_names(sig)
isva, _ = is_vararg_and_sparam_names(sig_or_mi)

arg_types = map(_type, ir.argtypes)
arg_fwds_types = Tuple{map(fcodual_type, arg_types)...}
Expand Down Expand Up @@ -743,35 +747,35 @@ function build_rrule(args...; safety_on=false)
end

"""
build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C}
build_rrule(interp::PInterp{C}, sig_or_mi; safety_on=false) where {C}

Returns a `DerivedRule` which is an `rrule!!` for `sig` in context `C`. See the docstring
Returns a `DerivedRule` which is an `rrule!!` for `sig_or_mi` in context `C`. See the docstring
for `rrule!!` for more info.

If `safety_on` is `true`, then all calls to rules are replaced with calls to `SafeRRule`s.
"""
function build_rrule(
interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false, silence_safety_messages=true
interp::PInterp{C}, sig_or_mi; safety_on=false, silence_safety_messages=true
) where {C}

# If we're compiling in safe mode, let the user know by default.
if !silence_safety_messages && safety_on
@info "Compiling rule for $sig in safe mode. Disable for best performance."
@info "Compiling rule for $sig_or_mi in safe mode. Disable for best performance."
end

# Reset id count. This ensures that the IDs generated are the same each time this
# function runs.
seed_id!()

# If we have a hand-coded rule, just use that.
is_primitive(C, sig) && return (safety_on ? SafeRRule(rrule!!) : rrule!!)
_is_primitive(C, sig_or_mi) && return (safety_on ? SafeRRule(rrule!!) : rrule!!)

# Grab code associated to the primal.
ir, _ = lookup_ir(interp, sig)
ir, _ = lookup_ir(interp, sig_or_mi)
Treturn = Base.Experimental.compute_ir_rettype(ir)

# Normalise the IR, and generated BBCode version of it.
isva, spnames = is_vararg_sig_and_sparam_names(sig)
isva, spnames = is_vararg_and_sparam_names(sig_or_mi)
ir = normalise!(ir, spnames)
primal_ir = BBCode(ir)

Expand All @@ -791,8 +795,8 @@ function build_rrule(

# If we've already derived the OpaqueClosures and info, do not re-derive, just create a
# copy and pass in new shared data.
if haskey(interp.oc_cache, (sig, safety_on))
existing_fwds_oc, existing_pb_oc = interp.oc_cache[(sig, safety_on)]
if haskey(interp.oc_cache, (sig_or_mi, safety_on))
existing_fwds_oc, existing_pb_oc = interp.oc_cache[(sig_or_mi, safety_on)]
fwds_oc = replace_captures(existing_fwds_oc, shared_data)
pb_oc = replace_captures(existing_pb_oc, shared_data)
else
Expand All @@ -801,7 +805,7 @@ function build_rrule(

optimised_fwds_ir = optimise_ir!(optimise_ir!(IRCode(fwds_ir); do_inline=true))
optimised_pb_ir = optimise_ir!(optimise_ir!(IRCode(pb_ir); do_inline=true))
# @show sig
# @show sig_or_mi
# @show Treturn
# @show safety_on
# display(ir)
Expand All @@ -820,10 +824,10 @@ function build_rrule(
OpaqueClosure(optimised_pb_ir, shared_data...; do_compile=true),
optimised_pb_ir,
)
interp.oc_cache[(sig, safety_on)] = (fwds_oc, pb_oc)
interp.oc_cache[(sig_or_mi, safety_on)] = (fwds_oc, pb_oc)
end

raw_rule = rule_type(interp, sig)(fwds_oc, pb_oc, Val(isva), Val(num_args(info)))
raw_rule = rule_type(interp, sig_or_mi)(fwds_oc, pb_oc, Val(isva), Val(num_args(info)))
return safety_on ? SafeRRule(raw_rule) : raw_rule
end

Expand Down Expand Up @@ -1230,7 +1234,7 @@ function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N}
end

#=
LazyDerivedRule(interp, sig, safety_on::Bool)
LazyDerivedRule(interp, mi::Core.MethodInstance, safety_on::Bool)

For internal use only.

Expand All @@ -1242,19 +1246,20 @@ If `safety_on` is `true`, then the rule constructed will be a `SafeRRule`. This
when debugging, but should usually be switched off for production code as it (in general)
incurs some runtime overhead.
=#
mutable struct LazyDerivedRule{sig, Tinterp<:TapirInterpreter, Trule}
mutable struct LazyDerivedRule{Tinterp<:TapirInterpreter, Trule}
interp::Tinterp
safety_on::Bool
mi::Core.MethodInstance
rule::Trule
function LazyDerivedRule(interp::A, ::Type{sig}, safety_on::Bool) where {A, sig}
rt = safety_on ? SafeRRule{rule_type(interp, sig)} : rule_type(interp, sig)
return new{sig, A, rt}(interp, safety_on)
function LazyDerivedRule(interp::A, mi::Core.MethodInstance, safety_on::Bool) where {A}
rt = rule_type(interp, mi)
return new{A, safety_on ? SafeRRule{rt} : rt}(interp, safety_on, mi)
end
end

function (rule::LazyDerivedRule{sig})(args::Vararg{Any, N}) where {N, sig}
function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N}
if !isdefined(rule, :rule)
rule.rule = build_rrule(rule.interp, sig; safety_on=rule.safety_on)
rule.rule = build_rrule(rule.interp, rule.mi; safety_on=rule.safety_on)
end
return rule.rule(args...)
end
2 changes: 2 additions & 0 deletions src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ tangent_type(::Type{Nothing}) = NoTangent

tangent_type(::Type{Expr}) = NoTangent

tangent_type(::Type{Core.TypeofVararg}) = NoTangent

tangent_type(::Type{SimpleVector}) = Vector{Any}

tangent_type(::Type{P}) where {P<:Union{UInt8, UInt16, UInt32, UInt64, UInt128}} = NoTangent
Expand Down
15 changes: 15 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,18 @@ end

test_getfield_of_tuple_of_types(n::Int) = getfield((Float64, Float64), n)

test_for_invoke(x) = 5x

inlinable_invoke_call(x::Float64) = invoke(test_for_invoke, Tuple{Float64}, x)

vararg_test_for_invoke(n::Tuple{Int, Int}, x...) = sum(x) + n[1]

function inlinable_vararg_invoke_call(
rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N}
) where {N}
return invoke(vararg_test_for_invoke, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...)
end

function generate_test_functions()
return Any[
(false, :allocs, nothing, const_tester),
Expand Down Expand Up @@ -1621,6 +1633,9 @@ function generate_test_functions()
(false, :none, nothing, ArgumentError, "hi"),
(false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}(5.0)),
(false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}([1.0])),
(false, :allocs, nothing, inlinable_invoke_call, 5.0),
(false, :none, nothing, inlinable_vararg_invoke_call, (2, 2), 5.0, 4.0, 3.0, 2.0),
(false, :none, nothing, hvcat, (2, 2), 3.0, 2.0, 0.0, 1.0),
]
end

Expand Down
28 changes: 21 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,33 @@ The usual function `map` doesn't enforce this for `Array`s.
end

#=
is_vararg_sig_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}
is_vararg_and_sparam_names(m::Method)

Returns a 2-tuple. The first element is true if the method associated to `sig` is a vararg
method, and false if not. The second element contains all of the names of the static
parameters associated to said method.
Returns a 2-tuple. The first element is true if `m` is a vararg method, and false if not.
The second element contains the names of the static parameters associated to `m`.
=#
function is_vararg_sig_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}
is_vararg_and_sparam_names(m::Method) = m.isva, sparam_names(m)

#=
is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}

Finds the method associated to `sig`, and calls `is_vararg_and_sparam_names` on it.
=#
function is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}
world = Base.get_world_counter()
min = Base.RefValue{UInt}(typemin(UInt))
max = Base.RefValue{UInt}(typemax(UInt))
ms = Base._methods_by_ftype(sig, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))::Vector
m = only(ms).method
return m.isva, sparam_names(m)
return is_vararg_and_sparam_names(only(ms).method)
end

#=
is_vararg_and_sparam_names(mi::Core.MethodInstance)

Calls `is_vararg_and_sparam_names` on `mi.def::Method`.
=#
function is_vararg_and_sparam_names(mi::Core.MethodInstance)::Tuple{Bool, Vector{Symbol}}
return is_vararg_and_sparam_names(mi.def)
end

# Returns the names of all of the static parameters in `m`.
Expand Down
31 changes: 18 additions & 13 deletions test/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,25 @@ end
@testset "lazy construction checks" begin
# Check that lazy construction is in fact lazy for some cases where performance
# really matters -- floats, things with no rdata, etc.
@testset "$p" for (p, fully_lazy) in Any[
(5, true),
(Int32(5), true),
(5.0, true),
(5f0, true),
(Float16(5.0), true),
(StructFoo(5.0), false),
(StructFoo(5.0, randn(4)), false),
(Bool, true),
(Tapir.TestResources.StableFoo, true),
@testset "$p" for (P, p, fully_lazy) in Any[
(Int, 5, true),
(Int32, Int32(5), true),
(Float64, 5.0, true),
(Float32, 5f0, true),
(Float16, Float16(5.0), true),
(StructFoo, StructFoo(5.0), false),
(StructFoo, StructFoo(5.0, randn(4)), false),
(Type{Bool}, Bool, true),
(Type{Tapir.TestResources.StableFoo}, Tapir.TestResources.StableFoo, true),
(Tuple{Float64, Float64}, (5.0, 4.0), true),
(Tuple{Float64, Vararg{Float64}}, (5.0, 4.0, 3.0), false),
]
@test fully_lazy == Base.issingletontype(typeof(lazy_zero_rdata(p)))
@inferred Tapir.instantiate(lazy_zero_rdata(p))
@test typeof(lazy_zero_rdata(p)) == Tapir.lazy_zero_rdata_type(_typeof(p))
L = Tapir.lazy_zero_rdata_type(P)
@test fully_lazy == Base.issingletontype(typeof(lazy_zero_rdata(L, p)))
if isconcretetype(P)
@inferred Tapir.instantiate(lazy_zero_rdata(L, p))
end
@test typeof(lazy_zero_rdata(L, p)) == Tapir.lazy_zero_rdata_type(P)
end
@test isa(
lazy_zero_rdata(Tapir.TestResources.StableFoo),
Expand Down
Loading