Skip to content

Commit

Permalink
Precompilation is cool, we should do more of it
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 3, 2024
1 parent 3ad827f commit dc63b98
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 150 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ObjectFile = "d8793406-e978-5875-9003-1fc021f44a92"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
2 changes: 2 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1587,4 +1587,6 @@ Returns true if within autodiff, otherwise false.
"""
@inline EnzymeCore.within_autodiff() = false

include("precompile.jl")

end # module
150 changes: 0 additions & 150 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,159 +4,11 @@ using Libdl

module FFI
using LLVM
module BLASSupport
# TODO: LAPACK handling
using LinearAlgebra
using ObjectFile
using Libdl
function __init__()
global blas_handle = Libdl.dlopen(BLAS.libblastrampoline)
end
function get_blas_symbols()
symbols = BLAS.get_config().exported_symbols
if BLAS.USE_BLAS64
return map(Base.Fix2(*, "64_"), symbols)
end
return symbols
end

function lookup_blas_symbol(name::String)
Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error = false)
end
end

const ptr_map = Dict{Ptr{Cvoid},String}()

function __init__()
known_names = (
"jl_alloc_array_1d",
"jl_alloc_array_2d",
"jl_alloc_array_3d",
"ijl_alloc_array_1d",
"ijl_alloc_array_2d",
"ijl_alloc_array_3d",
"jl_new_array",
"ijl_new_array",
"jl_array_copy",
"ijl_array_copy",
"jl_alloc_string",
"jl_in_threaded_region",
"jl_enter_threaded_region",
"jl_exit_threaded_region",
"jl_set_task_tid",
"jl_new_task",
"malloc",
"memmove",
"memcpy",
"memset",
"jl_array_grow_beg",
"ijl_array_grow_beg",
"jl_array_grow_end",
"ijl_array_grow_end",
"jl_array_grow_at",
"ijl_array_grow_at",
"jl_array_del_beg",
"ijl_array_del_beg",
"jl_array_del_end",
"ijl_array_del_end",
"jl_array_del_at",
"ijl_array_del_at",
"jl_array_ptr",
"ijl_array_ptr",
"jl_value_ptr",
"jl_get_ptls_states",
"jl_gc_add_finalizer_th",
"jl_symbol_n",
"jl_",
"jl_object_id",
"jl_reshape_array",
"ijl_reshape_array",
"jl_matching_methods",
"ijl_matching_methods",
"jl_array_sizehint",
"ijl_array_sizehint",
"jl_get_keyword_sorter",
"ijl_get_keyword_sorter",
"jl_ptr_to_array",
"jl_box_float32",
"ijl_box_float32",
"jl_box_float64",
"ijl_box_float64",
"jl_ptr_to_array_1d",
"jl_eqtable_get",
"ijl_eqtable_get",
"memcmp",
"memchr",
"jl_get_nth_field_checked",
"ijl_get_nth_field_checked",
"jl_stored_inline",
"ijl_stored_inline",
"jl_array_isassigned",
"ijl_array_isassigned",
"jl_array_ptr_copy",
"ijl_array_ptr_copy",
"jl_array_typetagdata",
"ijl_array_typetagdata",
"jl_idtable_rehash",
)
for name in known_names
sym = LLVM.find_symbol(name)
if sym == C_NULL
continue
end
if haskey(ptr_map, sym)
# On MacOS memcpy and memmove seem to collide?
if name == "memcpy"
continue
end
end
@assert !haskey(ptr_map, sym)
ptr_map[sym] = name
end
for sym in BLASSupport.get_blas_symbols()
ptr = BLASSupport.lookup_blas_symbol(sym)
if ptr !== nothing
if haskey(ptr_map, ptr)
if ptr_map[ptr] != sym
@warn "Duplicated symbol in ptr_map" ptr, sym, ptr_map[ptr]
end
continue
end
ptr_map[ptr] = sym
end
end
end

function memoize!(ptr::Ptr{Cvoid}, fn::String)::String
fn = get(ptr_map, ptr, fn)
if !haskey(ptr_map, ptr)
ptr_map[ptr] = fn
else
@assert ptr_map[ptr] == fn
end
return fn
end
end

import GPUCompiler: IRError, InvalidIRError

function restore_lookups(mod::LLVM.Module)::Nothing
T_size_t = convert(LLVM.LLVMType, Int)
for (v, k) in FFI.ptr_map
if haskey(functions(mod), k)
f = functions(mod)[k]
replace_uses!(
f,
LLVM.Value(
LLVM.API.LLVMConstIntToPtr(
ConstantInt(T_size_t, convert(UInt, v)),
value_type(f),
),
),
)
eraseInst(mod, f)
end
end
for f in functions(mod)
for fattr in collect(function_attributes(f))
if isa(fattr, LLVM.StringAttribute)
Expand Down Expand Up @@ -648,8 +500,6 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp
return errors
end

const libjulia = Ref{Ptr{Cvoid}}(C_NULL)

# List of methods to location of arg which is the mi/function, then start of args
const generic_method_offsets = Dict{String,Tuple{Int,Int}}((
"jl_f__apply_latest" => (2, 3),
Expand Down
13 changes: 13 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using PrecompileTools: @setup_workload, @compile_workload

@setup_workload begin
precompile_module = @eval module $(gensym())
f(x) = x^2
end

kernel() = nothing

@compile_workload begin
Enzyme.autodiff(Reverse, precompile_module.f, Active(2.0))
end
end

0 comments on commit dc63b98

Please sign in to comment.